diff --git a/.gitignore b/.gitignore
index e1fa12ea6ad..bdcb067fc26 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,15 +1,17 @@
 .DS_Store
 .ipynb_checkpoints
 node_modules
+/.bazelrc
+/.tf_configure.bazelrc
 /bazel-*
-/third_party/py/numpy/numpy_include
-/tools/bazel.rc
+/bazel_pip
+/third_party/eigen3/mkl_include
+/third_party/mkl/*
 /tools/python_bin_path.sh
 /tools/git/gen
-/util/python/python_include
-/util/python/python_lib
 /pip_test
 /_python_build
 *.pyc
 __pycache__
 *.swp
+.vscode/
diff --git a/.mention-bot b/.mention-bot
deleted file mode 100644
index 9e4858977f5..00000000000
--- a/.mention-bot
+++ /dev/null
@@ -1,11 +0,0 @@
-{
-  "maxReviewers": 2,
-  "numFilesToCheck": 10,
-  "userBlacklist": ["tensorflower-gardener"],
-  "requiredOrgs": ["tensorflow"],
-  "skipAlreadyAssignedPR": true,
-  "skipAlreadyMentionedPR": true,
-  "skipTitle": "Branch",
-  "delayed": true,
-  "delayedUntil": "10m"
-}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 36f2f9808e6..43abdaafbf4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -21,9 +21,151 @@ If you have improvements to TensorFlow, send us your pull requests! For those
 just getting started, Github has a
 [howto](https://help.github.com/articles/using-pull-requests/).
 
 If you want to contribute but you're not sure where to start, take a look at the
-[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/contributions%20welcome).
+[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
 These are issues that we believe are particularly well suited for outside
 contributions, often because we probably won't get to them right now. If you
 decide to start on an issue, leave a comment so that other people know that
 you're working on it. If you want to help out, but not alone, use the issue
 comment thread to coordinate.
+
+### Contribution guidelines and standards
+
+Before sending your pull request for
+[review](https://github.com/tensorflow/tensorflow/pulls),
+make sure your changes are consistent with the guidelines and follow the
+TensorFlow coding style.
+
+#### General guidelines and philosophy for contribution
+
+* Include unit tests when you contribute new features, as they help to
+  a) prove that your code works correctly, and b) guard against future breaking
+  changes to lower the maintenance cost (a minimal example appears below, after
+  this list).
+* Bug fixes also generally require unit tests, because the presence of bugs
+  usually indicates insufficient test coverage.
+* Keep API compatibility in mind when you change code in core TensorFlow,
+  e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
+  TensorFlow has reached version 1 and hence cannot make
+  non-backward-compatible API changes without a major release. Reviewers of your
+  pull request will comment on any API compatibility issues.
+* When you contribute a new feature to TensorFlow, the maintenance burden is (by
+  default) transferred to the TensorFlow team. This means that the benefit of
+  the contribution must be compared against the cost of maintaining the
+  feature.
+* Full new features (e.g., a new op implementing a cutting-edge algorithm)
+  typically will live in
+  [tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib)
+  to get some airtime before a decision is made regarding whether they are to
+  be migrated to the core.
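+
+One possible shape for such a unit test is sketched below, using the
+`tf.test.TestCase` helper; the `SquareTest` name and the op under test are
+illustrative placeholders rather than a project convention:
+
+```python
+import tensorflow as tf
+
+
+class SquareTest(tf.test.TestCase):
+  """Checks a single op against a couple of hand-computed values."""
+
+  def testSquare(self):
+    with self.test_session():
+      x = tf.square([2, 3])  # op under test
+      self.assertAllEqual(x.eval(), [4, 9])
+
+
+if __name__ == "__main__":
+  tf.test.main()
+```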
+
+#### License
+
+Include a license at the top of new files.
+
+* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1)
+* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1)
+* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1)
+* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1)
+* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2)
+* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2)
+* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1)
+
+Bazel BUILD files also need to include a license section, e.g.,
+[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61).
+
+#### C++ coding style
+
+Changes to TensorFlow C++ code should conform to the
+[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
+
+Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on
+ubuntu:16.04, do:
+
+```bash
+apt-get install -y clang-tidy
+```
+
+You can check the formatting of a C/C++ file with `clang-format`:
+
+```bash
+clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
+diff <my_cc_file> /tmp/my_cc_file.cc
+```
+
+#### Python coding style
+
+Changes to TensorFlow Python code should conform to the
+[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
+
+Use `pylint` to check your Python changes. To install `pylint` and
+retrieve TensorFlow's custom style definition:
+
+```bash
+pip install pylint
+wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc
+```
+
+To check a file with `pylint`:
+
+```bash
+pylint --rcfile=/tmp/pylintrc myfile.py
+```
+
+#### Coding style for other languages
+
+* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
+* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
+* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
+
+#### Running sanity check
+
+If you have Docker installed on your system, you can perform a sanity check on
+your changes by running the command:
+
+```bash
+tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh
+```
+
+This will catch most license, Python coding style, and BUILD file issues that
+may exist in your changes.
+
+#### Running unit tests
+
+There are two ways to run TensorFlow unit tests.
+
+1. Using tools and libraries installed directly on your system.
+
+   Refer to the
+   [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and
+   [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu)
+   for the required packages. Alternatively, use those
+   [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
+   `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu`,
+   for development to avoid installing the packages directly on your system.
+
+   Once you have the packages installed, you can run a specific unit test
+   with Bazel as follows.
+
+   If the tests are to be run on GPU, add the CUDA paths to `LD_LIBRARY_PATH`
+   and add the `cuda` config option:
+
+   ```bash
+   export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+
+   export flags="--config=opt --config=cuda -k"
+   ```
+
+   For example, to run all tests under tensorflow/python, do:
+
+   ```bash
+   bazel test ${flags} //tensorflow/python/...
+   ```
+
+2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts.
+
+   ```bash
+   # Install Docker first, then this will build and run CPU tests
+   tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...
+   ```
+
+   See
+   [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details.
+
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index af76188c2f4..5b37028c509 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -1,36 +1,37 @@
-NOTE: Only file GitHub issues for bugs and feature requests. All other topics will be closed.
+Please go to Stack Overflow for help and support:
-For general support from the community, see [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow).
-To make bugs and feature requests more easy to find and organize, we close issues that are deemed
-out of scope for GitHub Issues and point people to StackOverflow.
+http://stackoverflow.com/questions/tagged/tensorflow
-For bugs or installation issues, please provide the following information.
-The more information you provide, the more easily we will be able to offer
-help and advice.
+If you open a GitHub issue, here is our policy:
-### What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?
+1. It must be a bug or a feature request.
+2. The form below must be filled out.
+3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
-### Environment info
-Operating System:
+**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
-Installed version of CUDA and cuDNN:
-(please attach the output of `ls -l /path/to/cuda/lib/libcud*`):
+------------------------
-If installed from binary pip package, provide:
+### System information
+- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
+- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
+- **TensorFlow installed from (source or binary)**:
+- **TensorFlow version (use command below)**:
+- **Bazel version (if compiling from source)**:
+- **CUDA/cuDNN version**:
+- **GPU model and memory**:
+- **Exact command to reproduce**:
-1. A link to the pip package you installed:
-2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
+You can collect some of this information using our environment capture script:
-If installed from source, provide
+https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
-1. The commit hash (`git rev-parse HEAD`)
-2.
The output of `bazel version` +You can obtain the TensorFlow version with -### If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code) +python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +### Describe the problem +Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request. -### What other attempted solutions have you tried? - - -### Logs or other output that would be helpful -(If logs are large, please upload as attachment or provide link). +### Source code / logs +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem. diff --git a/README.md b/README.md index 40e8a4b190c..cbc94c1ab2b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@


+ ----------------- | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | @@ -25,19 +26,20 @@ guidelines](CONTRIBUTING.md).** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for tracking requests and bugs, but please see -[Community](tensorflow/g3doc/resources/index.md#community) for general questions +[Community](https://www.tensorflow.org/community/) for general questions and discussion.** ## Installation -*See [Download and Setup](tensorflow/g3doc/get_started/os_setup.md) for instructions on how to install our release binaries or how to build from source.* +*See [Installing TensorFlow](https://www.tensorflow.org/install/) for instructions on how to install our release binaries or how to build from source.* People who are a little more adventurous can also try our nightly binaries: -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) -* Mac CPU-only: [Python 
2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) -* [Android](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/lastSuccessfulBuild/artifact/bazel-out/local_linux/bin/tensorflow/examples/android/tensorflow_demo.apk) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-android/TF_BUILD_CONTAINER_TYPE=ANDROID,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=NO_PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=android-slave/)) +* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) +* Linux GPU: [Python 
2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) +* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp36-cp36m-win_amd64.whl) ([build 
history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) @@ -50,7 +52,7 @@ $ python >>> hello = tf.constant('Hello, TensorFlow!') >>> sess = tf.Session() >>> sess.run(hello) -Hello, TensorFlow! +'Hello, TensorFlow!' >>> a = tf.constant(10) >>> b = tf.constant(32) >>> sess.run(a+b) @@ -58,11 +60,11 @@ Hello, TensorFlow! >>> ``` -##For more information +## For more information -* [TensorFlow website](http://tensorflow.org) +* [TensorFlow website](https://tensorflow.org) * [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) -The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list. +The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/about/#community) for an incomplete list. diff --git a/RELEASE.md b/RELEASE.md index ab3ecbd7746..9875838d7e1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,312 @@ +# Release 1.2.0 + +## Major Features and Improvements +* Python 3.6 support on Windows. +* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution. +* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times. +* Added libverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo). +* Bring `tf.feature_column.*` into the API. Non-deprecated functionality from `tf.contrib.layers.*` is moved to `tf.feature_column.*` with cosmetic changes. +* `RNNCell` objects now subclass `tf.layers.Layer`. The strictness described + in the TensorFlow 1.1 release is gone: The first time an RNNCell is used, + it caches its scope. All future uses of the RNNCell will reuse variables from + that same scope. This is a breaking change from the behavior of RNNCells + in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to + ensure old code works correctly with the new semantics; this version + allows more flexible uses of RNNCell but can lead to subtle errors if + using code meant for TensorFlow <= 1.0.1. For example, writing: + `MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each + layer shares the **same** parameters. To get 5 layers each with their own + parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`. 
+  If at all unsure, first test your code with TF 1.1; ensure it raises no
+  errors, and then upgrade to TF 1.2. A short sketch of both patterns follows
+  this list.
+* RNNCells' variable names have been renamed for consistency with Keras layers.
+  Specifically, the previous variable names "weights" and "biases" have
+  been changed to "kernel" and "bias", respectively.
+  This may cause backward incompatibility with regard to your old
+  checkpoints containing such RNN cells, in which case you can use the
+  [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
+  to convert the variable names in your old checkpoints.
+* Many of the RNN functions and classes that were in the `tf.nn` namespace
+  before the 1.0 release and which were moved to `tf.contrib.rnn` have now
+  been moved back to the core namespace. This includes
+  `RNNCell`, `LSTMCell`, `GRUCell`, and a number of other cells. These
+  now reside in `tf.nn.rnn_cell` (with aliases in `tf.contrib.rnn` for backwards
+  compatibility). The original `tf.nn.rnn` function is now `tf.nn.static_rnn`,
+  and the bidirectional static and state saving static rnn functions are also
+  now back in the `tf.nn` namespace.
+
+  Notable exceptions are the `EmbeddingWrapper`, `InputProjectionWrapper` and
+  `OutputProjectionWrapper`, which will slowly be moved to deprecation
+  in `tf.contrib.rnn`. These are inefficient wrappers that should often
+  be replaced by calling `embedding_lookup` or `layers.dense` as pre- or post-
+  processing of the rnn. For RNN decoding, this functionality has been replaced
+  with an alternative API in `tf.contrib.seq2seq`.
+* Intel MKL Integration (https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture). Intel developed a number of
+  optimized deep learning primitives. In addition to matrix multiplication and
+  convolution, these building blocks include:
+  * Direct batched convolution
+  * Pooling: maximum, minimum, average
+  * Normalization: LRN, batch normalization
+  * Activation: rectified linear unit (ReLU)
+  * Data manipulation: multi-dimensional transposition (conversion), split,
+    concat, sum, and scale.
+* TensorForest Estimator now supports SavedModel export for serving.
+* Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters.
+* TensorFlow C library now available for Windows.
+* We released a new open-source version of TensorBoard.
+* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute a MetaGraph in a SavedModel.
+* Android releases of TensorFlow are now pushed to jcenter for easier
+  integration into apps. See
+  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md
+  for more details.
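+
+As a quick illustration of the two `MultiRNNCell` patterns above, here is a
+minimal sketch; the cell size is a placeholder chosen for illustration, and
+the snippet is not part of the tested release artifacts:
+
+```python
+import tensorflow as tf
+
+num_units = 128  # placeholder size, for illustration only
+
+# One LSTMCell instance repeated 5 times: under TF 1.2 semantics all five
+# layers reuse the SAME kernel and bias variables.
+cell = tf.nn.rnn_cell.LSTMCell(num_units)
+shared_stack = tf.nn.rnn_cell.MultiRNNCell([cell] * 5)
+
+# Five distinct LSTMCell instances: each layer gets its own parameters.
+stack = tf.nn.rnn_cell.MultiRNNCell(
+    [tf.nn.rnn_cell.LSTMCell(num_units) for _ in range(5)])
+```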
+
+## Deprecations
+
+* TensorFlow 1.2 may be the last time we build with cuDNN 5.1. Starting with
+  TensorFlow 1.3, we will try to build all our prebuilt binaries with cuDNN 6.0.
+  While we will try to keep our source code compatible with cuDNN 5.1, it will
+  be best effort.
+
+## Breaking Changes to the API
+* `org.tensorflow.contrib.android.TensorFlowInferenceInterface` now throws exceptions where possible and has simplified method signatures.
+
+## Changes to contrib APIs
+* Added `tf.contrib.util.create_example`.
+* Added bilinear interpolation to `tf.contrib.image`.
+* Added `tf.contrib.stateless` for random ops with custom seed control.
+* Added `MultivariateNormalFullCovariance` to contrib/distributions/.
+* tensorflow/contrib/rnn undergoes RNN cell variable renaming for
+  consistency with Keras layers. Specifically, the previous variable names
+  "weights" and "biases" are changed to "kernel" and "bias", respectively.
+  This may cause backward incompatibility with regard to your old
+  checkpoints containing such RNN cells, in which case you can use the
+  [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py)
+  to convert the variable names in your old checkpoints.
+* Added `tf.contrib.kernel_methods` module with Ops and estimators for primal
+  (explicit) kernel methods in TensorFlow.
+
+## Bug Fixes and Other Changes
+* In Python, `Operation.get_attr` on type attributes returns the Python DType
+  version of the type to match expected get_attr documentation rather than the
+  protobuf enum.
+* Changed MIN_SDK version to 8.0 when building iOS libraries.
+* Fixed LIBXSMM integration.
+* Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type.
+* Improve implicit broadcasting lowering.
+* Improved stability of GCS/BigQuery clients by faster retrying of stale transmissions.
+* Remove OpKernelConstruction::op_def() as part of minimizing proto dependencies.
+* VectorLaplaceDiag distribution added.
+* Android demo no longer requires libtensorflow_demo.so to run (libtensorflow_inference.so still required).
+* Added `categorical_column_with_vocabulary_file`.
+* Introduce ops for batching/unbatching tensors across Session::Run() calls.
+* Add tf.log_sigmoid(x) = tf.log(tf.sigmoid(x)) = -tf.nn.softplus(-x).
+* Changed hooks lists to immutable tuples, and now allow any iterable for the associated arguments.
+* Introduce TFDecorator.
+* Added an Mfcc op for speech feature generation.
+* Improved DirectSession::Run() overhead and error checking. Feeding a value of the wrong type will now synchronously raise an INVALID_ARGUMENT error instead of asynchronously raising an INTERNAL error. Code that depends on the (undefined) behavior when feeding a tensor of the wrong type may need to be updated.
+* Added unreduced NONE, and reduced MEAN options for losses. Removed "WEIGHTED_" prefix from other Reduction constants.
+* assertAllClose now handles dicts.
+* Added Gmock matcher for HloInstructions.
+* Add var name to errors on variable restore.
+* Added an AudioSpectrogram op for audio feature generation.
+* Added `reduction` arg to losses.
+* `tf.placeholder` can represent scalar and partially known shapes.
+* Remove estimator_spec(mode) argument.
+* TensorBoard disables all runs by default if there are more than 40 runs.
+* Removed old doc generator code.
+* GCS file system integration now supports domain buckets, e.g. gs://bucket.domain.com/path.
+* Add `tf.summary.text` for outputting text to TensorBoard.
+* The "run" command of tfdbg's command-line interface now supports filtering of tensors by node name, op type, and tensor dtype.
+* `tf.string_to_number` now supports int64 and float64 outputs.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+4F2E4A2E, Aaron Schumacher, Abhi Agg, admcrae, Adriano Carmezim, Adrià Arrufat,
+agramesh1, Akimitsu Seo, Alan Mosca, Alex Egg, Alex Rothberg, Alexander Heinecke,
+Alexander Matyasko, Alexandr Baranezky, Alexandre Caulier, Ali Siddiqui, Anand Venkat,
+Andrew Hundt, Androbin, Anmol Sharma, Arie, Arno Leist, Arron Cao, Aurélien Geron, Bairen Yi,
+Beomsu Kim, Carl Thomé, cfperez, Changming Sun, Corey Wharton, critiqjo, Dalei Li, Daniel
+Rasmussen, Daniel Trebbien, Darío Hereñú, David Eng, David Norman, David Y.
Zhang, Davy Song, ddurham2,
+Deepak Subburam, Dmytro Kyrychuk, Dominic Rossi, Dominik Schlösser, Dustin Tran,
+Eduardo Pinho, Egil Martinsson, Elliot Saba, Eric Bigelow, Erik Smistad, Evan Klitzke,
+Fabrizio Milo, Falcon Dai, Fei Gao, FloopCZ, Fung Lam, Gautam, GBLin5566, Greg Peatfield,
+Gu Wang, Guenther Schmuelling, Hans Pabst, Harun Gunaydin, Huaizheng, Ido Shamay, Ikaro
+Silva, Ilya Edrenkin, Immexxx, James Mishra, Jamie Cooke, Jay Young, Jayaram Bobba,
+Jianfei Wang, jinghua2, Joey Meyer, John Maidens, Jonghoon Jin, Julian Villella,
+Jun Kim, Jun Shi, Junwei Pan, jyegerlehner, Karan Desai, Karel Van De Plassche,
+Kb Sriram, KhabarlakKonstantin, Koan-Sin Tan, krivard, Kwotsin, Leandro Gracia Gil,
+Li Chen, Liangliang He, Louie Helm, lspvic, Luiz Henrique Soares, László Csomor,
+Mark Wong, Mathew Wicks, Matthew Rahtz, Maxwell Paul Brickner, Michael Hofmann, Miguel
+Flores Ruiz De Eguino, MikeTam1021, Mortada Mehyar, Mycosynth, Namnamseo,
+Nate Harada, Neven Miculinic, Nghia Tran, Nick Lyu, Niranjan Hasabnis, Nishidha, Oleksii
+Kuchaiev, Oyesh Mann Singh, Panmari, Patrick, Paul Van Eck, Piyush Chaudhary, Quim Llimona,
+Raingo, Richard Davies, Ruben Vereecken, Sahit Chintalapudi, Sam Abrahams, Santiago Castro,
+Scott Sievert, Sean O'Keefe, Sebastian Schlecht, Shane, Shubhankar Deshpande, Spencer Schaber,
+Sunyeop Lee, t13m, td2014, Thomas H. P. Andersen, Toby Petty, Umang Mehta,
+Vadim Markovtsev, Valentin Iovene, Vincent Zhao, Vit Stepanovs, Vivek Rane, Vu Pham, wannabesrevenge,
+weipingpku, wuhaixutab, wydwww, Xiang Gao, Xiaolin Lin, xiaoyaozhuzi, Yaroslav Bulatov, Yi Liu,
+Yoshihiro Sugi, Yuan (Terry) Tang, Yuming Wang, Yuxin Wu, Zader Zheng, Zhaojun Zhang, zhengjiajin,
+ZhipengShen, Ziming Dong, zjj2wry
+
+We are also grateful to all who filed issues or helped resolve them, asked and
+answered questions, and were part of inspiring discussions.
+
+# Release 1.1.0
+
+## Major Features and Improvements
+* Added Java API support for Windows.
+* Added `tf.spectral` module. Moved existing FFT ops to `tf.spectral` while
+  keeping an alias in the old location (`tf.*`).
+* Added 1D, 2D and 3D Fourier transform ops for real signals to `tf.spectral`.
+* Added a `tf.bincount` function (a small usage sketch appears after the
+  Deprecations section below).
+* Added Keras 2 API to contrib.
+* Added a new lightweight queue-like object - `RecordInput`.
+* Added `tf.contrib.image.compose_transforms` function.
+* Bring `tf.estimator.*` into the API. Non-deprecated functionality from `tf.contrib.learn.Estimator` is moved to `tf.estimator.Estimator` with cosmetic changes.
+* Docker images: TF images on gcr.io and Docker Hub are upgraded to ubuntu:16.04.
+* Added the following features to TensorFlow Debugger (tfdbg):
+  * Ability to inspect Python source files against TF ops and tensors (command `print_source` / `ps`)
+  * New navigation bar in Curses-based UI
+  * NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption.
+* Initial release of installation guides for Java, C, and Go.
+* Added Text Dashboard to TensorBoard.
+
+## Deprecations
+
+* TensorFlow 1.1.0 will be the last time we release a binary with Mac GPU support. Going forward, we will stop testing on Mac GPU systems. We continue to welcome patches that maintain Mac GPU support, and we will try to keep the Mac GPU build working.
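+
+As a hedged aside on the `tf.bincount` addition noted above, a minimal usage
+sketch (the input values are made up for illustration):
+
+```python
+import tensorflow as tf
+
+# Counts occurrences of each non-negative integer in the input.
+counts = tf.bincount(tf.constant([1, 1, 2, 5]))
+
+with tf.Session() as sess:
+  print(sess.run(counts))  # -> [0 2 1 0 0 1]
+```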
+
+## Changes to contrib APIs
+* The behavior of RNNCells is now stricter due to the transition towards making RNNCells act more like Keras layers.
+  * If an RNNCell is used twice in two different variable scopes, an error is raised describing how to avoid this behavior.
+  * If an RNNCell is used in a variable scope with existing conflicting variables, an error is raised showing that the RNNCell must be constructed with argument `reuse=True`.
+* Deprecated contrib/distributions `pmf`, `pdf`, `log_pmf`, `log_pdf`.
+* Moved `bayesflow.special_math` to distributions.
+* `tf.contrib.tensor_forest.python.tensor_forest.RandomForestDeviceAssigner` removed.
+* Changed some MVN classes and parameters:
+  * `tf.contrib.distributions.MultivariateNormalFull` replaced by `tf.contrib.distributions.MultivariateNormalTriL`.
+  * `tf.contrib.distributions.MultivariateNormalCholesky` replaced by `tf.contrib.distributions.MultivariateNormalTriL`.
+  * `tf.contrib.distributions.MultivariateNormalDiagWithSoftplusStDev` replaced
+    by `tf.contrib.distributions.MultivariateNormalDiagWithSoftplusScale`.
+  * `tf.contrib.distributions.MultivariateNormalDiag` arguments changed from `mu`, `diag_stddev` to `loc`, `scale_diag`.
+  * `tf.contrib.distributions.MultivariateNormalDiagPlusVDVT` removed.
+  * `tf.contrib.distributions.MultivariateNormalDiagPlusLowRank` added.
+
+## Bug Fixes and Other Changes
+* Java: Support for loading models exported using the SavedModel API (courtesy @EronWright).
+* Go: Added support for incremental graph execution.
+* Fix a bug in the WALS solver when single-threaded.
+* Added support for integer sparse feature values in `tf.contrib.layers.sparse_column_with_keys`.
+* Fixed `tf.set_random_seed(0)` to be deterministic for all ops.
+* Stability improvements for the GCS file system support.
+* Improved TensorForest performance.
+* Added support for multiple filename globs in `tf.matching_files`.
+* `LogMessage` now includes a timestamp at the beginning of each message.
+* Added MultiBox person detector example standalone binary.
+* Android demo: Makefile build functionality added to build.gradle to fully support building TensorFlow demo in Android on Windows.
+* Android demo: reads MultiBox priors from a txt file rather than a protobuf.
+* Added colocation constraints to `StagingArea`.
+* `sparse_matmul_op` reenabled for Android builds.
+* Restrict weights rank to be the same as the broadcast target, to avoid ambiguity on broadcast rules.
+* Upgraded libxsmm to 1.7.1 and applied other changes for performance and memory usage.
+* Fixed bfloat16 integration of LIBXSMM sparse mat-mul.
+* Improved performance and reduced memory usage by allowing ops to forward input buffers to output buffers and perform computations in-place.
+* Improved the performance of CPU assignment for strings.
+* Speed up matrix * vector multiplication and matrix * matrix with unknown shapes.
+* C API: Graph imports now support input remapping, control dependencies, and returning imported nodes (see `TF_GraphImportGraphDefWithReturnOutputs()`).
+* Multiple C++ API updates.
+* Multiple TensorBoard updates including:
+  * Users can now view image summaries at various sampled steps (instead of just the last step).
+  * Bugs involving switching runs as well as the image dashboard are fixed.
+  * Removed data download links from TensorBoard.
+  * TensorBoard uses a relative data directory, for easier embedding.
+  * TensorBoard automatically ignores outliers for domain calculation, and formats proportional values consistently.
+* Multiple tfdbg bug fixes:
+  * Fixed Windows compatibility issues.
+  * Command history now persists across runs.
+  * Bug fix in graph validation related to `tf.while_loops`.
+* Java Maven fixes for bugs with Windows installation.
+* Backport fixes and improvements from external Keras.
+* Keras config file handling fix.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex
+Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton
+Loss, @Aravind, @Arie, Ashutosh Das, Aurélien Geron, Bairen Yi, @bakunyo, Ben
+Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher
+Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel
+Ylitalo, Darren Garvey, David Norman, David Truong, @DavidNorman, Dimitar
+Pavlov, Dmitry Persiyanov, @Eddie, @elirex, Erfan Noury, Eron Wright, Evgeny
+Mazovetskiy, Fabrizio (Misto) Milo, @fanlu, Fisher Coder, Florian Courtial,
+Franck Dernoncourt, Gagan Goel, Gao, Xiang, @Gautam, Gefu Tang, @guilherme,
+@guschmue, Hannah Provenza, Hans Pabst, @hartb, Hsiao Yi, Huazuo Gao, Igor
+Chorążewicz, Ivan Smirnov, Jakub Kolodziejczyk, Jason Gavris, Jason Morton, Jay
+Young, Jayaram Bobba, Jeremy Sawruk, Jiaming Liu, Jihun Choi, @jiqiu, Joan Thibault,
+John C F, Jojy George Varghese, Jon Malmaud, Julian Berman, Julian Niedermeier,
+Junpeng Lao, Kai Sasaki, @Kankroc, Karl Lessard, Kyle Bostelmann, @Lezcano, Li
+Yi, Luo Yun, @lurker, Mahmoud-Abuzaina, Mandeep Singh, Marek Kolodziej, Mark
+Szepieniec, Martial Hue, Medhat Omr, Memo Akten, Michael Gharbi, Michaël Defferrard,
+Milan Straka, @MircoT, @mlucool, Muammar Ibn Faisal, Nayana Thorat, @nghiattran,
+Nicholas Connor, Nikolaas Steenbergen, Niraj Patel, Niranjan Hasabnis, @Panmari,
+Pavel Bulanov, Philip Pries Henningsen, Philipp Jund, @polonez, Prayag Verma, Rahul
+Kavi, Raphael Gontijo Lopes, @rasbt, Raven Iqqe, Reid Pryzant, Richard Shin, Rizwan
+Asif, Russell Kaplan, Ryo Asakura, Rüdiger Busche, Saisai Shao, Sam Abrahams, @sanosay,
+Sean Papay, @seaotterman, @selay01, Shaurya Sharma, Sriram Narayanamoorthy, Stefano
+Probst, @taknevski, @tbonza, @teldridge11, Tim Anglade, Tomas Reimers, Tomer Gafner,
+Valentin Iovene, Vamsi Sripathi, Viktor Malyi, Vit Stepanovs, Vivek Rane, Vlad Firoiu,
+@wangg12, @will, Xiaoyu Tao, Yaroslav Bulatov, Yi Liu, Yuan (Terry) Tang, @Yufeng,
+Yuming Wang, Yuxin Wu, Zafar Takhirov, Ziming Dong
+
+We are also grateful to all who filed issues or helped resolve them, asked and
+answered questions, and were part of inspiring discussions.
+
+
+# Release 1.0.1
+
+## Bug Fixes and Other Changes
+* Change GraphConstructor to not increase the version when importing, but instead take the min of all versions.
+* Google Cloud Storage fixes.
+* Removed `tf.core` and `tf.python` modules from the API. These were never intended to be exposed. Please use the same objects through the top-level `tf` module instead.
+
 # Release 1.0.0
 
 ## Major Features and Improvements
@@ -87,8 +396,12 @@ To help you upgrade your existing TensorFlow Python code to match the API change
 * In the C++ API (in tensorflow/cc), Input, Output, etc. have moved
   from the tensorflow::ops namespace to tensorflow.
 * Change arg order for `{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits` to be (labels, predictions), and force use of named args.
+* tf.nn.rnn_cell.* and most functions in tf.nn.rnn.* (with the exception of dynamic_rnn and raw_rnn) are temporarily in tf.contrib.rnn. They will be moved back into core for TF 1.2.
+* `tf.nn.sampled_softmax_loss` and `tf.nn.nce_loss` have both changed their API such that you need to switch the `inputs, labels` to `labels, inputs` parameters.
+* The shape keyword argument of the `SparseTensor` constructor changes its name to `dense_shape` between TensorFlow 0.12 and TensorFlow 1.0.
 
 ## Bug Fixes and Other Changes
+* Numerous C++ API updates.
 * New op: `parallel_stack`.
 * Introducing common tf io compression options constants for
   RecordReader/RecordWriter.
@@ -127,6 +440,7 @@ To help you upgrade your existing TensorFlow Python code to match the API change
 * `tf.divide` now honors the name field.
 * Make metrics weight broadcasting more strict.
 * Add new queue-like `StagingArea` and new ops: `stage` and `unstage`.
+* Enable inplace update ops for strings on CPU. Speed up string concat.
 
 ## Thanks to our Contributors
@@ -193,7 +507,7 @@ answered questions, and were part of inspiring discussions.
   indexing now starts from 1 instead of 0, and `bus_id==0` is used where
   previously `BUS_ANY` was used.
 * `Env::FileExists` and `FileSystem::FileExists` now return a tensorflow::Status
-  intead of a bool. Any callers to this function can be converted to a bool
+  instead of a bool. Any callers to this function can be converted to a bool
   by adding .ok() to the call.
 * The C API type `TF_SessionWithGraph` has been renamed to `TF_Session`,
   indicating its preferred use in language bindings for TensorFlow.
@@ -212,7 +526,7 @@ answered questions, and were part of inspiring discussions.
 * `SparseTensor.shape` has been renamed to `SparseTensor.dense_shape`. Same for
   `SparseTensorValue.shape`.
 * `Env::FileExists` and `FileSystem::FileExists` now return a
-  `tensorflow::Status` intead of a bool. Any callers to this function can be
+  `tensorflow::Status` instead of a bool. Any callers to this function can be
   converted to a bool by adding `.ok()` to the call.
 * C API: Type `TF_SessionWithGraph` has been renamed to `TF_Session`,
   indicating its preferred use in language bindings for TensorFlow.
What was previously diff --git a/WORKSPACE b/WORKSPACE index 958a53c30ed..74ce13f4e88 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce", - strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d", + sha256 = "bc41b80486413aaa551860fc37471dbc0666e1dbb5236fb6177cb83b0c105846", + strip_prefix = "rules_closure-dec425a4ff3faf09a56c85d082e4eed05d8ce38f", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03 - "https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz", # 2017-06-02 + "https://github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz", ], ) @@ -14,510 +14,56 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -load("//tensorflow:workspace.bzl", "check_version", "tf_workspace") - -# We must check the bazel version before trying to parse any other BUILD files, -# in case the parsing of those build files depends on the bazel version we -# require here. -check_version("0.4.2") +load("//tensorflow:workspace.bzl", "tf_workspace") # Uncomment and update the paths in these entries to build the Android demo. #android_sdk_repository( # name = "androidsdk", # api_level = 23, -# build_tools_version = "23.0.1", +# # Ensure that you have the build_tools_version below installed in the +# # SDK manager as it updates periodically. +# build_tools_version = "25.0.2", # # Replace with path to Android SDK on your system # path = "", #) # +# Android NDK r12b is recommended (higher may cause issues with Bazel) #android_ndk_repository( # name="androidndk", # path="", -# api_level=21) +# # This needs to be 14 or higher to compile TensorFlow. +# # Note that the NDK version is not the API level. +# api_level=14) # Please add all new TensorFlow dependencies in workspace.bzl. 
tf_workspace() new_http_archive( - name = "inception5h", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", - sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364" + name = "inception5h", + build_file = "models.BUILD", + sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", + "http://download.tensorflow.org/models/inception5h.zip", + ], ) new_http_archive( - name = "mobile_multibox", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", - sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96" + name = "mobile_multibox", + build_file = "models.BUILD", + sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", + "http://download.tensorflow.org/models/mobile_multibox_v1a.zip", + ], ) new_http_archive( - name = "stylize", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", - sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa" -) - -# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT - -new_http_archive( - name = "d3", - build_file = "bower.BUILD", - url = "https://github.com/mbostock-bower/d3-bower/archive/v3.5.15.tar.gz", - strip_prefix = "d3-bower-3.5.15", -) - -new_http_archive( - name = "dagre", - build_file = "bower.BUILD", - url = "https://github.com/cpettitt/dagre/archive/v0.7.4.tar.gz", - strip_prefix = "dagre-0.7.4", -) - -new_http_archive( - name = "es6_promise", - build_file = "bower.BUILD", - url = "https://github.com/components/es6-promise/archive/v2.1.0.tar.gz", - strip_prefix = "es6-promise-2.1.0", -) - -new_http_archive( - name = "font_roboto", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/font-roboto/archive/v1.0.1.tar.gz", - strip_prefix = "font-roboto-1.0.1", -) - -new_http_archive( - name = "graphlib", - build_file = "bower.BUILD", - url = "https://github.com/cpettitt/graphlib/archive/v1.0.7.tar.gz", - strip_prefix = "graphlib-1.0.7", -) - -new_http_archive( - name = "iron_a11y_announcer", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - strip_prefix = "iron-a11y-announcer-1.0.5", -) - -new_http_archive( - name = "iron_a11y_keys_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - strip_prefix = "iron-a11y-keys-behavior-1.1.8", -) - -new_http_archive( - name = "iron_ajax", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-ajax/archive/v1.2.0.tar.gz", - strip_prefix = "iron-ajax-1.2.0", -) - -new_http_archive( - name = "iron_autogrow_textarea", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - strip_prefix = "iron-autogrow-textarea-1.0.12", -) - -new_http_archive( - name = "iron_behaviors", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-behaviors/archive/v1.0.17.tar.gz", - strip_prefix = "iron-behaviors-1.0.17", -) - -new_http_archive( - name = "iron_checked_element_behavior", - build_file = "bower.BUILD", - url = 
"https://github.com/polymerelements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - strip_prefix = "iron-checked-element-behavior-1.0.4", -) - -new_http_archive( - name = "iron_collapse", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-collapse/archive/v1.0.8.tar.gz", - strip_prefix = "iron-collapse-1.0.8", -) - -new_http_archive( - name = "iron_dropdown", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-dropdown/archive/v1.4.0.tar.gz", - strip_prefix = "iron-dropdown-1.4.0", -) - -new_http_archive( - name = "iron_fit_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-fit-behavior/archive/v1.2.5.tar.gz", - strip_prefix = "iron-fit-behavior-1.2.5", -) - -new_http_archive( - name = "iron_flex_layout", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-flex-layout/archive/v1.3.0.tar.gz", - strip_prefix = "iron-flex-layout-1.3.0", -) - -new_http_archive( - name = "iron_form_element_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - strip_prefix = "iron-form-element-behavior-1.0.6", -) - -new_http_archive( - name = "iron_icon", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-icon/archive/v1.0.11.tar.gz", - strip_prefix = "iron-icon-1.0.11", -) - -new_http_archive( - name = "iron_icons", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-icons/archive/v1.1.3.tar.gz", - strip_prefix = "iron-icons-1.1.3", -) - -new_http_archive( - name = "iron_iconset_svg", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz", - strip_prefix = "iron-iconset-svg-1.1.0", -) - -new_http_archive( - name = "iron_input", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-input/archive/1.0.10.tar.gz", - strip_prefix = "iron-input-1.0.10", -) - -new_http_archive( - name = "iron_list", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-list/archive/v1.3.9.tar.gz", - strip_prefix = "iron-list-1.3.9", -) - -new_http_archive( - name = "iron_menu_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-menu-behavior/archive/v1.1.10.tar.gz", - strip_prefix = "iron-menu-behavior-1.1.10", -) - -new_http_archive( - name = "iron_meta", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-meta/archive/v1.1.1.tar.gz", - strip_prefix = "iron-meta-1.1.1", -) - -new_http_archive( - name = "iron_overlay_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - strip_prefix = "iron-overlay-behavior-1.10.1", -) - -new_http_archive( - name = "iron_range_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-range-behavior/archive/v1.0.4.tar.gz", - strip_prefix = "iron-range-behavior-1.0.4", -) - -new_http_archive( - name = "iron_resizable_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - strip_prefix = "iron-resizable-behavior-1.0.3", -) - -new_http_archive( - name = "iron_scroll_target_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - strip_prefix = "iron-scroll-target-behavior-1.0.3", -) - -new_http_archive( - 
name = "iron_selector", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-selector/archive/v1.5.2.tar.gz", - strip_prefix = "iron-selector-1.5.2", -) - -new_http_archive( - name = "iron_validatable_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - strip_prefix = "iron-validatable-behavior-1.1.1", -) - -new_http_archive( - name = "lodash", - build_file = "bower.BUILD", - url = "https://github.com/lodash/lodash/archive/3.8.0.tar.gz", - strip_prefix = "lodash-3.8.0", -) - -new_http_archive( - name = "neon_animation", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/neon-animation/archive/v1.2.2.tar.gz", - strip_prefix = "neon-animation-1.2.2", -) - -http_file( - name = "numericjs_numeric_min_js", - url = "https://cdnjs.cloudflare.com/ajax/libs/numeric/1.2.6/numeric.min.js", -) - -new_http_archive( - name = "paper_behaviors", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-behaviors/archive/v1.0.12.tar.gz", - strip_prefix = "paper-behaviors-1.0.12", -) - -new_http_archive( - name = "paper_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-button/archive/v1.0.11.tar.gz", - strip_prefix = "paper-button-1.0.11", -) - -new_http_archive( - name = "paper_checkbox", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-checkbox/archive/v1.4.0.tar.gz", - strip_prefix = "paper-checkbox-1.4.0", -) - -new_http_archive( - name = "paper_dialog", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog/archive/v1.0.4.tar.gz", - strip_prefix = "paper-dialog-1.0.4", -) - -new_http_archive( - name = "paper_dialog_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - strip_prefix = "paper-dialog-behavior-1.2.5", -) - -new_http_archive( - name = "paper_dialog_scrollable", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - strip_prefix = "paper-dialog-scrollable-1.1.5", -) - -new_http_archive( - name = "paper_dropdown_menu", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - strip_prefix = "paper-dropdown-menu-1.4.0", -) - -new_http_archive( - name = "paper_header_panel", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-header-panel/archive/v1.1.4.tar.gz", - strip_prefix = "paper-header-panel-1.1.4", -) - -new_http_archive( - name = "paper_icon_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz", - strip_prefix = "paper-icon-button-1.1.3", -) - -new_http_archive( - name = "paper_input", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-input/archive/v1.1.18.tar.gz", - strip_prefix = "paper-input-1.1.18", -) - -new_http_archive( - name = "paper_item", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-item/archive/v1.1.4.tar.gz", - strip_prefix = "paper-item-1.1.4", -) - -new_http_archive( - name = "paper_listbox", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-listbox/archive/v1.1.2.tar.gz", - strip_prefix = "paper-listbox-1.1.2", -) - -new_http_archive( - name = "paper_material", - build_file = "bower.BUILD", - url = 
"https://github.com/polymerelements/paper-material/archive/v1.0.6.tar.gz", - strip_prefix = "paper-material-1.0.6", -) - -new_http_archive( - name = "paper_menu", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-menu/archive/v1.2.2.tar.gz", - strip_prefix = "paper-menu-1.2.2", -) - -new_http_archive( - name = "paper_menu_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-menu-button/archive/v1.5.1.tar.gz", - strip_prefix = "paper-menu-button-1.5.1", -) - -new_http_archive( - name = "paper_progress", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-progress/archive/v1.0.9.tar.gz", - strip_prefix = "paper-progress-1.0.9", -) - -new_http_archive( - name = "paper_radio_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-radio-button/archive/v1.1.2.tar.gz", - strip_prefix = "paper-radio-button-1.1.2", -) - -new_http_archive( - name = "paper_radio_group", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-radio-group/archive/v1.0.9.tar.gz", - strip_prefix = "paper-radio-group-1.0.9", -) - -new_http_archive( - name = "paper_ripple", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-ripple/archive/v1.0.5.tar.gz", - strip_prefix = "paper-ripple-1.0.5", -) - -new_http_archive( - name = "paper_slider", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-slider/archive/v1.0.10.tar.gz", - strip_prefix = "paper-slider-1.0.10", -) - -new_http_archive( - name = "paper_spinner", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-spinner/archive/v1.1.1.tar.gz", - strip_prefix = "paper-spinner-1.1.1", -) - -new_http_archive( - name = "paper_styles", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-styles/archive/v1.1.4.tar.gz", - strip_prefix = "paper-styles-1.1.4", -) - -new_http_archive( - name = "paper_tabs", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-tabs/archive/v1.7.0.tar.gz", - strip_prefix = "paper-tabs-1.7.0", -) - -new_http_archive( - name = "paper_toast", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toast/archive/v1.3.0.tar.gz", - strip_prefix = "paper-toast-1.3.0", -) - -new_http_archive( - name = "paper_toggle_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toggle-button/archive/v1.2.0.tar.gz", - strip_prefix = "paper-toggle-button-1.2.0", -) - -new_http_archive( - name = "paper_toolbar", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toolbar/archive/v1.1.4.tar.gz", - strip_prefix = "paper-toolbar-1.1.4", -) - -new_http_archive( - name = "paper_tooltip", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-tooltip/archive/v1.1.2.tar.gz", - strip_prefix = "paper-tooltip-1.1.2", -) - -new_http_archive( - name = "plottable", - build_file = "bower.BUILD", - url = "https://github.com/palantir/plottable/archive/v1.16.1.tar.gz", - strip_prefix = "plottable-1.16.1", -) - -new_http_archive( - name = "polymer_archive", - build_file = "bower.BUILD", - url = "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz", - strip_prefix = "polymer-1.7.0", -) - -new_http_archive( - name = "promise_polyfill", - build_file = "bower.BUILD", - url = "https://github.com/polymerlabs/promise-polyfill/archive/v1.0.0.tar.gz", - strip_prefix = "promise-polyfill-1.0.0", -) - 
-http_file( - name = "three_js_three_min_js", - url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/build/three.min.js", -) - -http_file( - name = "three_js_orbitcontrols_js", - url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/examples/js/controls/OrbitControls.js", -) - -new_http_archive( - name = "web_animations_js", - build_file = "bower.BUILD", - url = "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - strip_prefix = "web-animations-js-2.2.1", -) - -new_http_archive( - name = "webcomponentsjs", - build_file = "bower.BUILD", - url = "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - strip_prefix = "webcomponentsjs-0.7.22", -) - -http_file( - name = "weblas_weblas_js", - url = "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", + name = "stylize", + build_file = "models.BUILD", + sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", + "http://download.tensorflow.org/models/stylize_v1.zip", + ], ) diff --git a/bower.BUILD b/bower.BUILD deleted file mode 100644 index eabd1d64507..00000000000 --- a/bower.BUILD +++ /dev/null @@ -1,645 +0,0 @@ -# AUTOGENERATED FILE by tensorboard_bower_dependency_sync.py - -package(default_visibility = ["//visibility:public"]) - -filegroup( - name = "d3", - srcs = [ - "d3.js", - "d3.min.js", - "package.js", - ], -) - -filegroup( - name = "dagre", - srcs = [ - "dist/dagre.core.js", - "dist/dagre.core.min.js", - ], -) - -filegroup( - name = "es6_promise", - srcs = [ - "promise.js", - "promise.min.js", - ], -) - -filegroup( - name = "font_roboto", - srcs = ["roboto.html"], -) - -filegroup( - name = "graphlib", - srcs = [ - "dist/graphlib.core.js", - "dist/graphlib.core.min.js", - ], -) - -filegroup( - name = "iron_a11y_announcer", - srcs = [ - "index.html", - "iron-a11y-announcer.html", - ], -) - -filegroup( - name = "iron_a11y_keys_behavior", - srcs = [ - "index.html", - "iron-a11y-keys-behavior.html", - ], -) - -filegroup( - name = "iron_ajax", - srcs = [ - "index.html", - "iron-ajax.html", - "iron-request.html", - ], -) - -filegroup( - name = "iron_autogrow_textarea", - srcs = [ - "index.html", - "iron-autogrow-textarea.html", - ], -) - -filegroup( - name = "iron_behaviors", - srcs = [ - "index.html", - "iron-button-state.html", - "iron-control-state.html", - ], -) - -filegroup( - name = "iron_checked_element_behavior", - srcs = [ - "index.html", - "iron-checked-element-behavior.html", - ], -) - -filegroup( - name = "iron_collapse", - srcs = [ - "index.html", - "iron-collapse.html", - ], -) - -filegroup( - name = "iron_dropdown", - srcs = [ - "index.html", - "iron-dropdown.html", - "iron-dropdown-scroll-manager.html", - ], -) - -filegroup( - name = "iron_fit_behavior", - srcs = [ - "index.html", - "iron-fit-behavior.html", - ], -) - -filegroup( - name = "iron_flex_layout", - srcs = [ - "classes/iron-flex-layout.html", - "classes/iron-shadow-flex-layout.html", - "index.html", - "iron-flex-layout.html", - "iron-flex-layout-classes.html", - ], -) - -filegroup( - name = "iron_form_element_behavior", - srcs = [ - "index.html", - "iron-form-element-behavior.html", - ], -) - -filegroup( - name = "iron_icon", - srcs = [ - "index.html", - "iron-icon.html", - ], -) - -filegroup( - name = "iron_icons", - srcs = [ - "av-icons.html", - "communication-icons.html", - "device-icons.html", - "editor-icons.html", - "hardware-icons.html", - "image-icons.html", - 
"index.html", - "iron-icons.html", - "maps-icons.html", - "notification-icons.html", - "places-icons.html", - "social-icons.html", - ], -) - -filegroup( - name = "iron_iconset_svg", - srcs = [ - "index.html", - "iron-iconset-svg.html", - ], -) - -filegroup( - name = "iron_input", - srcs = [ - "index.html", - "iron-input.html", - ], -) - -filegroup( - name = "iron_list", - srcs = [ - "index.html", - "iron-list.html", - "test/smoke/avg-worst-case.html", - "test/smoke/dummy-data.html", - "test/smoke/index.html", - "test/smoke/physical-count.html", - ], -) - -filegroup( - name = "iron_menu_behavior", - srcs = [ - "index.html", - "iron-menu-behavior.html", - "iron-menubar-behavior.html", - ], -) - -filegroup( - name = "iron_meta", - srcs = [ - "index.html", - "iron-meta.html", - ], -) - -filegroup( - name = "iron_overlay_behavior", - srcs = [ - "index.html", - "iron-focusables-helper.html", - "iron-overlay-backdrop.html", - "iron-overlay-behavior.html", - "iron-overlay-manager.html", - ], -) - -filegroup( - name = "iron_range_behavior", - srcs = [ - "index.html", - "iron-range-behavior.html", - ], -) - -filegroup( - name = "iron_resizable_behavior", - srcs = [ - "demo/src/x-app.html", - "index.html", - "iron-resizable-behavior.html", - ], -) - -filegroup( - name = "iron_scroll_target_behavior", - srcs = [ - "index.html", - "iron-scroll-target-behavior.html", - ], -) - -filegroup( - name = "iron_selector", - srcs = [ - "index.html", - "iron-multi-selectable.html", - "iron-selectable.html", - "iron-selection.html", - "iron-selector.html", - ], -) - -filegroup( - name = "iron_validatable_behavior", - srcs = [ - "index.html", - "iron-validatable-behavior.html", - ], -) - -filegroup( - name = "lodash", - srcs = [ - "lodash.js", - "lodash.min.js", - ], -) - -filegroup( - name = "neon_animation", - srcs = [ - "animations/cascaded-animation.html", - "animations/fade-in-animation.html", - "animations/fade-out-animation.html", - "animations/hero-animation.html", - "animations/opaque-animation.html", - "animations/reverse-ripple-animation.html", - "animations/ripple-animation.html", - "animations/scale-down-animation.html", - "animations/scale-up-animation.html", - "animations/slide-down-animation.html", - "animations/slide-from-bottom-animation.html", - "animations/slide-from-left-animation.html", - "animations/slide-from-right-animation.html", - "animations/slide-from-top-animation.html", - "animations/slide-left-animation.html", - "animations/slide-right-animation.html", - "animations/slide-up-animation.html", - "animations/transform-animation.html", - "demo/card/index.html", - "demo/card/x-card.html", - "demo/card/x-cards-list.html", - "demo/declarative/index.html", - "demo/doc/index.html", - "demo/doc/my-animatable.html", - "demo/doc/my-dialog.html", - "demo/dropdown/animated-dropdown.html", - "demo/dropdown/index.html", - "demo/grid/animated-grid.html", - "demo/grid/fullsize-page-with-card.html", - "demo/grid/index.html", - "demo/list/full-view.html", - "demo/list/index.html", - "demo/list/list-demo.html", - "demo/list/list-view.html", - "demo/load/animated-grid.html", - "demo/load/full-page.html", - "demo/load/index.html", - "demo/reprojection/animated-grid.html", - "demo/reprojection/fullsize-page-with-card.html", - "demo/reprojection/index.html", - "demo/reprojection/reprojected-pages.html", - "demo/tiles/circles-page.html", - "demo/tiles/index.html", - "demo/tiles/squares-page.html", - "index.html", - "neon-animatable.html", - "neon-animatable-behavior.html", - "neon-animated-pages.html", - 
"neon-animation.html", - "neon-animation-behavior.html", - "neon-animation-runner-behavior.html", - "neon-animations.html", - "neon-shared-element-animatable-behavior.html", - "neon-shared-element-animation-behavior.html", - "web-animations.html", - ], -) - -filegroup( - name = "paper_behaviors", - srcs = [ - "index.html", - "paper-button-behavior.html", - "paper-checked-element-behavior.html", - "paper-inky-focus-behavior.html", - "paper-ripple-behavior.html", - ], -) - -filegroup( - name = "paper_button", - srcs = [ - "index.html", - "paper-button.html", - ], -) - -filegroup( - name = "paper_checkbox", - srcs = [ - "index.html", - "paper-checkbox.html", - ], -) - -filegroup( - name = "paper_dialog", - srcs = [ - "index.html", - "paper-dialog.html", - ], -) - -filegroup( - name = "paper_dialog_behavior", - srcs = [ - "index.html", - "paper-dialog-behavior.html", - "paper-dialog-common.css", - "paper-dialog-shared-styles.html", - ], -) - -filegroup( - name = "paper_dialog_scrollable", - srcs = [ - "index.html", - "paper-dialog-scrollable.html", - ], -) - -filegroup( - name = "paper_dropdown_menu", - srcs = [ - "index.html", - "paper-dropdown-menu.html", - "paper-dropdown-menu-icons.html", - "paper-dropdown-menu-light.html", - "paper-dropdown-menu-shared-styles.html", - ], -) - -filegroup( - name = "paper_header_panel", - srcs = [ - "index.html", - "paper-header-panel.html", - ], -) - -filegroup( - name = "paper_icon_button", - srcs = [ - "index.html", - "paper-icon-button.html", - "paper-icon-button-light.html", - ], -) - -filegroup( - name = "paper_input", - srcs = [ - "all-imports.html", - "index.html", - "paper-input.html", - "paper-input-addon-behavior.html", - "paper-input-behavior.html", - "paper-input-char-counter.html", - "paper-input-container.html", - "paper-input-error.html", - "paper-textarea.html", - ], -) - -filegroup( - name = "paper_item", - srcs = [ - "all-imports.html", - "index.html", - "paper-icon-item.html", - "paper-item.html", - "paper-item-behavior.html", - "paper-item-body.html", - "paper-item-shared-styles.html", - ], -) - -filegroup( - name = "paper_listbox", - srcs = [ - "index.html", - "paper-listbox.html", - ], -) - -filegroup( - name = "paper_material", - srcs = [ - "index.html", - "paper-material.html", - "paper-material-shared-styles.html", - ], -) - -filegroup( - name = "paper_menu", - srcs = [ - "index.html", - "paper-menu.html", - "paper-menu-shared-styles.html", - "paper-submenu.html", - ], -) - -filegroup( - name = "paper_menu_button", - srcs = [ - "index.html", - "paper-menu-button.html", - "paper-menu-button-animations.html", - ], -) - -filegroup( - name = "paper_progress", - srcs = [ - "index.html", - "paper-progress.html", - ], -) - -filegroup( - name = "paper_radio_button", - srcs = [ - "index.html", - "paper-radio-button.html", - ], -) - -filegroup( - name = "paper_radio_group", - srcs = [ - "index.html", - "paper-radio-group.html", - ], -) - -filegroup( - name = "paper_ripple", - srcs = [ - "index.html", - "paper-ripple.html", - ], -) - -filegroup( - name = "paper_slider", - srcs = [ - "index.html", - "paper-slider.html", - ], -) - -filegroup( - name = "paper_spinner", - srcs = [ - "index.html", - "paper-spinner.html", - "paper-spinner-behavior.html", - "paper-spinner-lite.html", - "paper-spinner-styles.html", - ], -) - -filegroup( - name = "paper_styles", - srcs = [ - "classes/global.html", - "classes/shadow.html", - "classes/shadow-layout.html", - "classes/typography.html", - "color.html", - "default-theme.html", - "demo.css", - 
"demo-pages.html", - "index.html", - "paper-styles.html", - "paper-styles-classes.html", - "shadow.html", - "typography.html", - ], -) - -filegroup( - name = "paper_tabs", - srcs = [ - "index.html", - "paper-tab.html", - "paper-tabs.html", - "paper-tabs-icons.html", - ], -) - -filegroup( - name = "paper_toast", - srcs = [ - "index.html", - "paper-toast.html", - ], -) - -filegroup( - name = "paper_toggle_button", - srcs = [ - "index.html", - "paper-toggle-button.html", - ], -) - -filegroup( - name = "paper_toolbar", - srcs = [ - "index.html", - "paper-toolbar.html", - ], -) - -filegroup( - name = "paper_tooltip", - srcs = [ - "index.html", - "paper-tooltip.html", - ], -) - -filegroup( - name = "plottable", - srcs = [ - "plottable.css", - "plottable.js", - "plottable.min.js", - ], -) - -filegroup( - name = "polymer", - srcs = [ - "polymer.html", - "polymer-micro.html", - "polymer-mini.html", - ], -) - -filegroup( - name = "promise_polyfill", - srcs = [ - "Gruntfile.js", - "Promise.js", - "Promise.min.js", - "Promise-Statics.js", - "promise-polyfill.html", - "promise-polyfill-lite.html", - ], -) - -filegroup( - name = "web_animations_js", - srcs = [ - "web-animations.html", - "web-animations.min.js", - "web-animations-next.min.js", - "web-animations-next-lite.min.js", - ], -) - -filegroup( - name = "webcomponentsjs", - srcs = [ - "CustomElements.js", - "CustomElements.min.js", - "HTMLImports.js", - "HTMLImports.min.js", - "MutationObserver.js", - "MutationObserver.min.js", - "ShadowDOM.js", - "ShadowDOM.min.js", - "webcomponents.js", - "webcomponents.min.js", - "webcomponents-lite.js", - "webcomponents-lite.min.js", - ], -) diff --git a/configure b/configure index 4f1dc2a9102..602124225fe 100755 --- a/configure +++ b/configure @@ -3,6 +3,8 @@ set -e set -o pipefail +MIN_BAZEL_VERSION=0.4.5 + # Find out the absolute path to where ./configure resides pushd `dirname $0` > /dev/null SOURCE_BASE_DIR=`pwd -P` @@ -11,40 +13,179 @@ popd > /dev/null PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" function is_linux() { - if [[ "${PLATFORM}" == "linux" ]]; then - true - else - false - fi + [[ "${PLATFORM}" == "linux" ]] } function is_macos() { - if [[ "${PLATFORM}" == "darwin" ]]; then - true - else - false - fi + [[ "${PLATFORM}" == "darwin" ]] } function is_windows() { # On windows, the shell script is actually running in msys - if [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]; then - true - else - false - fi + [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] } -function bazel_clean_and_fetch() { - # bazel clean --expunge currently doesn't work on Windows - # TODO(pcloudy): Re-enable it after bazel clean --expunge is fixed. - if ! is_windows; then - bazel clean --expunge - fi - bazel fetch "//tensorflow/... -//tensorflow/contrib/nccl/... \ - -//tensorflow/examples/android/..." 
+function sed_in_place() { + sed -e $1 $2 > "$2.bak" + mv "$2.bak" $2 } +function write_to_bazelrc() { + echo "$1" >> .tf_configure.bazelrc +} + +function write_action_env_to_bazelrc() { + write_to_bazelrc "build --action_env $1=\"$2\"" +} + +function python_path { + "$PYTHON_BIN_PATH" - <&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + PYTHON_BIN_PATH="" + # Retry + done + + if [ -z "$PYTHON_LIB_PATH" ]; then + # Split python_path into an array of paths, this allows path containing spaces + IFS=',' read -r -a python_lib_path <<< "$(python_path)" + + if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then + PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + + else + echo "Found possible Python library paths:" + for x in "${python_lib_path[@]}"; do + echo " $x" + done + set -- "${python_lib_path[@]}" + echo "Please input the desired Python library path to use. Default is [$1]" + read b || true + if [ "$b" == "" ]; then + PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + else + PYTHON_LIB_PATH="$b" + fi + fi + fi + + if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then + echo "PYTHON_BIN_PATH is not executable. Is it the python binary?" + exit 1 + fi + + local python_major_version + python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);' | head -c1) + if [ -z "$python_major_version" ]; then + echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?" + exit 1 + fi + + # Convert python path to Windows style before writing into bazel.rc + if is_windows; then + PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")" + PYTHON_LIB_PATH="$(cygpath -m "$PYTHON_LIB_PATH")" + fi + + # Set-up env variables used by python_configure.bzl + write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH" + write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH" + write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "build --force_python=py$python_major_version" + write_to_bazelrc "build --host_force_python=py$python_major_version" + write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --force_python=py$python_major_version" + write_to_bazelrc "test --host_force_python=py$python_major_version" + write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "run --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + + # Write tools/python_bin_path.sh + echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh +} + +function version { + echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; +} + + +bazel version > bazel.version +curr_bazel_version=$(head -n 1 bazel.version | cut -d ' ' -f3) +rm -f bazel.version + + +echo "You have bazel $curr_bazel_version installed." +if [ -z "$curr_bazel_version" ]; then + echo "WARNING: current bazel installation is not a release version." + echo "Make sure you are running at least bazel $MIN_BAZEL_VERSION." +elif [ "$(version "$MIN_BAZEL_VERSION")" -gt "$(version "$curr_bazel_version")" ]; then + echo "Please upgrade your bazel installation to version $MIN_BAZEL_VERSION or higher to build TensorFlow!" 
+ echo "Exiting..." + exit 1 +fi + +# This file contains customized config settings. +rm -f .tf_configure.bazelrc +touch .tf_configure.bazelrc +if [[ ! -e .bazelrc ]]; then + if [[ -e "${HOME}/.bazelrc" ]]; then + echo "import ${HOME}/.bazelrc" >.bazelrc + else + touch .bazelrc + fi +fi +sed_in_place "/tf_configure/d" .bazelrc +echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc + # Delete any leftover BUILD files from the Makefile build, which would interfere # with Bazel parsing. MAKEFILE_DOWNLOAD_DIR=tensorflow/contrib/makefile/downloads @@ -52,58 +193,65 @@ if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete fi -## Set up python-related environment settings -while true; do - fromuser="" - if [ -z "$PYTHON_BIN_PATH" ]; then - default_python_bin_path=$(which python || which python3 || true) - read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH - fromuser="1" - if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$default_python_bin_path - fi - fi - if [ -e "$PYTHON_BIN_PATH" ]; then - break - fi - echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - PYTHON_BIN_PATH="" - # Retry -done +setup_python ## Set up MKL related environment settings -if false; then # Disable building with MKL for now - while [ "$TF_NEED_MKL" == "" ]; do +while [ "$TF_NEED_MKL" == "" ]; do + fromuser="" + read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT + fromuser="1" + case $INPUT in + [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; + [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +OSNAME=`uname -s` + +if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + while [ "$TF_DOWNLOAD_MKL" == "" ]; do fromuser="" - read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT + read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT fromuser="1" case $INPUT in - [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; - [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - * ) echo "Invalid selection: " $INPUT;; + [Yy]* ) TF_DOWNLOAD_MKL=1;; + [Nn]* ) TF_DOWNLOAD_MKL=0;; + "" ) TF_DOWNLOAD_MKL=1;; + * ) echo "Invalid selection: " $INPUT; exit 1;; esac done - OSNAME=`uname -s` - - if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then DST=`dirname $0` - ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz - GITHUB_RELEASE_TAG=v0.3 + ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz + GITHUB_RELEASE_TAG=v0.7 MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME" - if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then - wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL + if ! 
[ -e "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" ]; then + curl -fSsL -o "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" "${MKLURL}" fi tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/ extracted_dir_name="${ARCHIVE_BASENAME%.*}" MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - if [ "$OSNAME" == "Linux" ]; then + else + default_mkl_path=/opt/intel/mklml + fromuser="" + if [ -z "$MKL_INSTALL_PATH" ]; then + read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH + fromuser="1" + fi + if [ -z "$MKL_INSTALL_PATH" ]; then + MKL_INSTALL_PATH=$default_mkl_path + fi + # Result returned from "read" will be used unexpanded. That make "~" unusable. + # Going through one more level of expansion to handle that. + MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` + fi + + if [ "$OSNAME" == "Linux" ]; then # Full MKL configuration MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version? MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION? @@ -111,24 +259,29 @@ if false; then # Disable building with MKL for now # MKL-ML configuration MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version? MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION? - elif [ "$OSNAME" == "Darwin" ]; then + elif [ "$OSNAME" == "Darwin" ]; then echo "Darwin is unsupported yet"; exit 1 - fi + fi - if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then + if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - else - echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist"; + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then + ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + else + echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists"; exit 1 - fi - - if [ -z "$fromuser" ]; then - exit 1 - fi + fi cat > third_party/mkl/mkl.config < third_party/mkl/mkl.config <> tools/bazel.rc for opt in $CC_OPT_FLAGS; do - echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc + write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt" done # Run the gen_git_source to create links where bazel can track dependencies for @@ -289,35 +434,46 @@ while [ "$TF_NEED_CUDA" == "" ]; do done export TF_NEED_CUDA +write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA" + export TF_NEED_OPENCL -if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then - echo "Configuration finished" - bazel_clean_and_fetch - exit -fi +write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL" if [ "$TF_NEED_CUDA" == "1" ]; then -# Set up which gcc nvcc should use as the host compiler -# No need to set this on Windows -while ! 
is_windows && true; do +while [[ "$TF_CUDA_CLANG" == "" ]]; do + read -p "Do you want to use clang as CUDA compiler? [y/N] " INPUT + case $INPUT in + [Yy]* ) echo "Clang will be used as CUDA compiler"; TF_CUDA_CLANG=1;; + [Nn]* ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;; + "" ) echo "nvcc will be used as CUDA compiler"; TF_CUDA_CLANG=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +export TF_CUDA_CLANG +write_action_env_to_bazelrc "TF_CUDA_CLANG" "$TF_CUDA_CLANG" + +# Set up which clang we should use as the cuda / host compiler. +while [[ "$TF_CUDA_CLANG" == "1" ]] && true; do fromuser="" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - default_gcc_host_compiler_path=$(which gcc || true) - read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH + if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then + default_clang_host_compiler_path=$(which clang || true) + read -p "Please specify which clang should be used as device and host compiler. [Default is $default_clang_host_compiler_path]: " CLANG_CUDA_COMPILER_PATH fromuser="1" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" + if [ -z "$CLANG_CUDA_COMPILER_PATH" ]; then + CLANG_CUDA_COMPILER_PATH="$default_clang_host_compiler_path" fi fi - if [ -e "$GCC_HOST_COMPILER_PATH" ]; then - export GCC_HOST_COMPILER_PATH + if [ -e "$CLANG_CUDA_COMPILER_PATH" ]; then + export CLANG_CUDA_COMPILER_PATH + write_action_env_to_bazelrc "CLANG_CUDA_COMPILER_PATH" "$CLANG_CUDA_COMPILER_PATH" break fi - echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 + echo "Invalid clang path. ${CLANG_CUDA_COMPILER_PATH} cannot be found" 1>&2 if [ -z "$fromuser" ]; then exit 1 fi - GCC_HOST_COMPILER_PATH="" + CLANG_CUDA_COMPILER_PATH="" # Retry done @@ -325,7 +481,7 @@ done while true; do # Configure the Cuda SDK version to use. if [ -z "$TF_CUDA_VERSION" ]; then - read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: " TF_CUDA_VERSION + read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: " TF_CUDA_VERSION fi fromuser="" @@ -337,6 +493,11 @@ while true; do else default_cuda_path="$(cygpath -m "$CUDA_PATH")" fi + elif is_linux; then + # If the default doesn't exist, try an alternative default. + if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then + default_cuda_path=/opt/cuda + fi fi read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH fromuser="1" @@ -361,6 +522,7 @@ while true; do if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then export CUDA_TOOLKIT_PATH + write_action_env_to_bazelrc "CUDA_TOOLKIT_PATH" "$CUDA_TOOLKIT_PATH" export TF_CUDA_VERSION break fi @@ -374,11 +536,47 @@ while true; do CUDA_TOOLKIT_PATH="" done +# Set default CUDA version if not set +if [ -z "$TF_CUDA_VERSION" ]; then + TF_CUDA_VERSION="8.0" + export TF_CUDA_VERSION +fi +write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION" + +# Set up which gcc nvcc should use as the host compiler +# No need to set this on Windows +while [[ "$TF_CUDA_CLANG" != "1" ]] && ! 
is_windows && true; do + fromuser="" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + default_gcc_host_compiler_path=$(which gcc || true) + cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc" + if [ -L "$cuda_bin_symlink" ]; then + default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink) + fi + read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH + fromuser="1" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" + fi + fi + if [ -e "$GCC_HOST_COMPILER_PATH" ]; then + export GCC_HOST_COMPILER_PATH + write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" + break + fi + echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + GCC_HOST_COMPILER_PATH="" + # Retry +done + # Find out where the cuDNN library is installed while true; do - # Configure the Cudnn version to use. + # Configure the cuDNN version to use. if [ -z "$TF_CUDNN_VERSION" ]; then - read -p "Please specify the Cudnn version you want to use. [Leave empty to use system default]: " TF_CUDNN_VERSION + read -p "Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: " TF_CUDNN_VERSION fi fromuser="" @@ -389,7 +587,7 @@ while true; do if [ -z "$CUDNN_INSTALL_PATH" ]; then CUDNN_INSTALL_PATH=$default_cudnn_path fi - # Result returned from "read" will be used unexpanded. That make "~" unuseable. + # Result returned from "read" will be used unexpanded. That make "~" unusable. # Going through one more level of expansion to handle that. CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"` fi @@ -411,17 +609,26 @@ while true; do CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib" fi - if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then + if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" ] || [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then export TF_CUDNN_VERSION + write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION" export CUDNN_INSTALL_PATH + write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH" break fi if is_linux; then - CUDNN_PATH_FROM_LDCONFIG="$(ldconfig -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')" + if ! type ldconfig > /dev/null 2>&1; then + LDCONFIG_BIN=/sbin/ldconfig + else + LDCONFIG_BIN=ldconfig + fi + CUDNN_PATH_FROM_LDCONFIG="$($LDCONFIG_BIN -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')" if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then export TF_CUDNN_VERSION - export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})" + export CUDNN_INSTALL_PATH + CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})" + write_action_env_to_bazelrc "CUDNN_INSTALL_PATH" "$CUDNN_INSTALL_PATH" break fi fi @@ -440,6 +647,13 @@ while true; do CUDNN_INSTALL_PATH="" done +# Set default CUDNN version if not set +if [ -z "$TF_CUDNN_VERSION" ]; then + TF_CUDNN_VERSION="6" + export TF_CUDNN_VERSION +fi +write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION" + # Configure the compute capabilities that TensorFlow builds for. # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work. 
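The compute capability prompt in the loop below collects a comma-separated list of `major.minor` values, and with this change the result is mirrored into `.tf_configure.bazelrc` instead of living only in the current shell's environment. A hedged sketch of the net effect (the `3.5,5.2` value is an example, not a default introduced by this patch):

```bash
# Example value only; the patch does not hard-code "3.5,5.2".
# After the selection loop below succeeds, the script has effectively done:
export TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2"
echo "build --action_env TF_CUDA_COMPUTE_CAPABILITIES=\"3.5,5.2\"" >> .tf_configure.bazelrc
```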
while true; do @@ -473,6 +687,7 @@ EOF fi else export TF_CUDA_COMPUTE_CAPABILITIES + write_action_env_to_bazelrc "TF_CUDA_COMPUTE_CAPABILITIES" "$TF_CUDA_COMPUTE_CAPABILITIES" break fi TF_CUDA_COMPUTE_CAPABILITIES="" @@ -483,9 +698,15 @@ if is_windows; then export CUDA_PATH="$CUDA_TOOLKIT_PATH" export CUDA_COMPUTE_CAPABILITIES="$TF_CUDA_COMPUTE_CAPABILITIES" export NO_WHOLE_ARCHIVE_OPTION=1 - - # Set GCC_HOST_COMPILER_PATH to keep cuda_configure.bzl happy - export GCC_HOST_COMPILER_PATH="/usr/bin/dummy_compiler" + write_action_env_to_bazelrc "CUDA_PATH" "$CUDA_PATH" + write_action_env_to_bazelrc "CUDA_COMPUTE_CAPABILITIES" "$CUDA_COMPUTE_CAPABILITIES" + write_action_env_to_bazelrc "NO_WHOLE_ARCHIVE_OPTION" "1" + write_to_bazelrc "build --config=win-cuda" + write_to_bazelrc "test --config=win-cuda" +else + # If CUDA is enabled, always use GPU during build and test. + write_to_bazelrc "build --config=cuda" + write_to_bazelrc "test --config=cuda" fi # end of if "$TF_NEED_CUDA" == "1" @@ -499,7 +720,7 @@ if [ "$TF_NEED_OPENCL" == "1" ]; then while true; do fromuser="" if [ -z "$HOST_CXX_COMPILER" ]; then - default_cxx_host_compiler=$(which clang++-3.6 || true) + default_cxx_host_compiler=$(which g++ || true) read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER fromuser="1" if [ -z "$HOST_CXX_COMPILER" ]; then @@ -508,6 +729,7 @@ while true; do fi if [ -e "$HOST_CXX_COMPILER" ]; then export HOST_CXX_COMPILER + write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER" break fi echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2 @@ -522,7 +744,7 @@ done while true; do fromuser="" if [ -z "$HOST_C_COMPILER" ]; then - default_c_host_compiler=$(which clang-3.6 || true) + default_c_host_compiler=$(which gcc || true) read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER fromuser="1" if [ -z "$HOST_C_COMPILER" ]; then @@ -531,6 +753,7 @@ while true; do fi if [ -e "$HOST_C_COMPILER" ]; then export HOST_C_COMPILER + write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER" break fi echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2 @@ -561,6 +784,7 @@ while true; do if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then export COMPUTECPP_TOOLKIT_PATH + write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH" break fi echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found" @@ -576,6 +800,82 @@ done # end of if "$TF_NEED_OPENCL" == "1" fi -bazel_clean_and_fetch + +while [ "$TF_NEED_MPI" == "" ]; do + read -p "Do you wish to build TensorFlow with "\ +"MPI support? [y/N] " INPUT + case $INPUT in + [Yy]* ) echo "MPI support will be enabled for "\ +"TensorFlow"; TF_NEED_MPI=1;; + [Nn]* ) echo "MPI support will not be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + "" ) echo "MPI support will not be enabled for "\ +"TensorFlow"; TF_NEED_MPI=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +# Find out where the MPI toolkit is installed +while true; do + if [ "$TF_NEED_MPI" == "0" ]; then + break; + fi + + fromuser="" + if [ -z "$MPI_HOME" ]; then + #Get the base folder by removing the bin path + default_mpi_path=$(dirname $(dirname $(which mpirun)) || dirname $(dirname $(which mpiexec)) || true) + read -p "Please specify the MPI toolkit folder. 
[Default is $default_mpi_path]: " MPI_HOME + fromuser="1" + if [ -z "$MPI_HOME" ]; then + MPI_HOME=$default_mpi_path + fi + fi + + #Check that the include and library folders are where we expect them to be + if [ -e "$MPI_HOME/include" ] && [ -e "$MPI_HOME/lib" ]; then + break + fi + + echo "Invalid path to the MPI Toolkit. ${MPI_HOME}/include or ${MPI_HOME}/lib cannot be found." + if [ -z "$fromuser" ]; then + exit 1 + fi + + # Retry + MPI_HOME="" +done + + +if [ "$TF_NEED_MPI" == "1" ]; then + write_to_bazelrc 'build --define with_mpi_support=true' + + #Link the MPI header files + ln -sf "${MPI_HOME}/include/mpi.h" third_party/mpi/mpi.h + + + #Determine if we use OpenMPI or MVAPICH, these require different header files + #to be included here to make bazel dependency checker happy + + if [ -e "${MPI_HOME}/include/mpi_portable_platform.h" ]; then + #OpenMPI + ln -sf "${MPI_HOME}/include/mpi_portable_platform.h" third_party/mpi/ + sed -i -e "s/MPI_LIB_IS_OPENMPI=False/MPI_LIB_IS_OPENMPI=True/" third_party/mpi/mpi.bzl + else + #MVAPICH / MPICH + ln -sf "${MPI_HOME}/include/mpio.h" third_party/mpi/ + ln -sf "${MPI_HOME}/include/mpicxx.h" third_party/mpi/ + sed -i -e "s/MPI_LIB_IS_OPENMPI=True/MPI_LIB_IS_OPENMPI=False/" third_party/mpi/mpi.bzl + fi + + + if [ -e "${MPI_HOME}/lib/libmpi.so" ]; then + ln -sf "${MPI_HOME}/lib/libmpi.so" third_party/mpi/ + else + echo "Cannot find the MPI library file in ${MPI_HOME}/lib " + exit 1 + fi +fi + echo "Configuration finished" diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 999e11c0e91..6450b2ad878 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -14,8 +14,33 @@ exports_files([ # Config setting for determining if we are building for Android. config_setting( name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86", values = { "crosstool_top": "//external:android/crosstool", + "cpu": "x86", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_x86_64", + values = { + "crosstool_top": "//external:android/crosstool", + "cpu": "x86_64", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android_armeabi", + values = { + "cc_target_os": "android", + "cpu": "armeabi", }, visibility = ["//visibility:public"], ) @@ -46,6 +71,12 @@ config_setting( config_setting( name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "windows_msvc", values = {"cpu": "x64_windows_msvc"}, visibility = ["//visibility:public"], ) @@ -58,9 +89,7 @@ config_setting( config_setting( name = "ios", - values = { - "crosstool_top": "//tools/osx/crosstool:crosstool", - }, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) @@ -70,6 +99,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_ppc64le", + values = {"cpu": "ppc"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -86,6 +121,61 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +# TODO(jhseu): Enable on other platforms other than Linux. 
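The `config_setting` stanzas added around here (the jemalloc ones the TODO above refers to follow immediately) expose `--define` flags that BUILD rules can `select()` on. The configure changes above already write one of these when MPI is enabled, namely `build --define with_mpi_support=true`; the same define can also be passed per invocation. An illustrative command, not one taken from this patch:

```bash
# Hypothetical one-off build with the MPI paths enabled, matching the
# "with_mpi_support" config_setting defined below; any buildable target works.
bazel build --define with_mpi_support=true //tensorflow/...
```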
+config_setting( + name = "with_jemalloc_linux_x86_64", + values = { + "cpu": "k8", + "define": "with_jemalloc=true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_jemalloc_linux_ppc64le", + values = { + "cpu": "ppc", + "define": "with_jemalloc=true", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support", + values = {"define": "with_gcp_support=true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support", + values = {"define": "with_hdfs_support=true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_xla_support", + values = {"define": "with_xla_support=true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_verbs_support", + values = {"define": "with_verbs_support=true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_mpi_support", + values = {"define": "with_mpi_support=true"}, + visibility = ["//visibility:public"], +) + package_group( name = "internal", packages = ["//tensorflow/..."], @@ -123,7 +213,9 @@ filegroup( "//tensorflow/compiler/aot/tests:all_files", "//tensorflow/compiler/jit:all_files", "//tensorflow/compiler/jit/graphcycles:all_files", + "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", + "//tensorflow/compiler/jit/ops:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", @@ -131,7 +223,6 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/port:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -141,11 +232,28 @@ filegroup( "//tensorflow/compiler/xla/tools:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/android:all_files", + "//tensorflow/contrib/batching:all_files", + "//tensorflow/contrib/batching/kernels:all_files", + "//tensorflow/contrib/batching/test_util:all_files", + "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", + "//tensorflow/contrib/boosted_trees:all_files", + "//tensorflow/contrib/boosted_trees/lib:all_files", + "//tensorflow/contrib/boosted_trees/proto:all_files", + "//tensorflow/contrib/boosted_trees/resources:all_files", + "//tensorflow/contrib/cloud:all_files", + "//tensorflow/contrib/cloud/kernels:all_files", + "//tensorflow/contrib/cluster_resolver:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", "//tensorflow/contrib/cudnn_rnn:all_files", + "//tensorflow/contrib/data:all_files", + "//tensorflow/contrib/data/python/framework:all_files", + "//tensorflow/contrib/data/python/kernel_tests:all_files", + "//tensorflow/contrib/data/python/ops:all_files", + "//tensorflow/contrib/data/python/util:all_files", + "//tensorflow/contrib/decision_trees:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", @@ -154,10 +262,15 @@ filegroup( "//tensorflow/contrib/framework:all_files", "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", + "//tensorflow/contrib/hooks:all_files", + 
"//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", "//tensorflow/contrib/image:all_files", + "//tensorflow/contrib/imperative:all_files", "//tensorflow/contrib/input_pipeline:all_files", "//tensorflow/contrib/input_pipeline/kernels:all_files", "//tensorflow/contrib/integrate:all_files", + "//tensorflow/contrib/keras:all_files", + "//tensorflow/contrib/kernel_methods:all_files", "//tensorflow/contrib/labeled_tensor:all_files", "//tensorflow/contrib/layers:all_files", "//tensorflow/contrib/layers/kernels:all_files", @@ -172,30 +285,44 @@ filegroup( "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/rnn:all_files", + "//tensorflow/contrib/saved_model:all_files", + "//tensorflow/contrib/saved_model/cc/saved_model:all_files", "//tensorflow/contrib/seq2seq:all_files", "//tensorflow/contrib/session_bundle:all_files", "//tensorflow/contrib/session_bundle/example:all_files", + "//tensorflow/contrib/signal:all_files", "//tensorflow/contrib/slim:all_files", "//tensorflow/contrib/slim/python/slim/data:all_files", "//tensorflow/contrib/slim/python/slim/nets:all_files", "//tensorflow/contrib/solvers:all_files", "//tensorflow/contrib/sparsemax:all_files", "//tensorflow/contrib/specs:all_files", + "//tensorflow/contrib/staging:all_files", "//tensorflow/contrib/stat_summarizer:all_files", + "//tensorflow/contrib/stateless:all_files", "//tensorflow/contrib/tensor_forest:all_files", "//tensorflow/contrib/tensor_forest/hybrid:all_files", "//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/testing:all_files", + "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", + "//tensorflow/contrib/verbs:all_files", + "//tensorflow/contrib/xla_tf_graph:all_files", "//tensorflow/core:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", + "//tensorflow/core/grappler:all_files", + "//tensorflow/core/grappler/clusters:all_files", + "//tensorflow/core/grappler/costs:all_files", + "//tensorflow/core/grappler/inputs:all_files", + "//tensorflow/core/grappler/optimizers:all_files", + "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", - "//tensorflow/core/kernels/cloud:all_files", "//tensorflow/core/kernels/hexagon:all_files", + "//tensorflow/core/kernels/neon:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", @@ -203,6 +330,7 @@ filegroup( "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", "//tensorflow/examples/android:all_files", + "//tensorflow/examples/benchmark:all_files", "//tensorflow/examples/how_tos/reading_data:all_files", "//tensorflow/examples/image_retraining:all_files", "//tensorflow/examples/label_image:all_files", @@ -211,30 +339,87 @@ filegroup( "//tensorflow/examples/tutorials/estimators:all_files", "//tensorflow/examples/tutorials/mnist:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", - "//tensorflow/g3doc/how_tos/adding_an_op:all_files", - "//tensorflow/g3doc/tutorials:all_files", + "//tensorflow/examples/wav_to_spectrogram:all_files", "//tensorflow/go:all_files", "//tensorflow/java:all_files", "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", "//tensorflow/java/src/main/native:all_files", 
"//tensorflow/python:all_files", "//tensorflow/python/debug:all_files", + "//tensorflow/python/estimator:all_files", + "//tensorflow/python/feature_column:all_files", "//tensorflow/python/kernel_tests:all_files", + "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", "//tensorflow/tensorboard:all_files", - "//tensorflow/tensorboard/app:all_files", "//tensorflow/tensorboard/backend:all_files", + "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/vz_data_summary:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files", + "//tensorflow/tensorboard/components/tf_backend:all_files", + "//tensorflow/tensorboard/components/tf_backend/test:all_files", + "//tensorflow/tensorboard/components/tf_color_scale:all_files", + "//tensorflow/tensorboard/components/tf_color_scale/test:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_globals:all_files", + "//tensorflow/tensorboard/components/tf_graph:all_files", + "//tensorflow/tensorboard/components/tf_graph/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_app:all_files", + "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_board:all_files", + "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_common:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files", + "//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_info:all_files", + "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_imports:all_files", + "//tensorflow/tensorboard/components/tf_option_selector:all_files", + "//tensorflow/tensorboard/components/tf_profile_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_runs_selector:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_storage:all_files", + "//tensorflow/tensorboard/components/tf_storage/test:all_files", + "//tensorflow/tensorboard/components/tf_tensorboard:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", + 
"//tensorflow/tensorboard/components/tf_trace_viewer:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", "//tensorflow/tensorboard/components/vz_projector:all_files", + "//tensorflow/tensorboard/components/vz_projector/test:all_files", "//tensorflow/tensorboard/components/vz_sorting:all_files", "//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/lib:all_files", - "//tensorflow/tensorboard/lib/python:all_files", + "//tensorflow/tensorboard/demo:all_files", + "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", + "//tensorflow/tensorboard/plugins:all_files", + "//tensorflow/tensorboard/plugins/audio:all_files", + "//tensorflow/tensorboard/plugins/distributions:all_files", + "//tensorflow/tensorboard/plugins/graphs:all_files", + "//tensorflow/tensorboard/plugins/histograms:all_files", + "//tensorflow/tensorboard/plugins/images:all_files", + "//tensorflow/tensorboard/plugins/projector:all_files", + "//tensorflow/tensorboard/plugins/scalars:all_files", + "//tensorflow/tensorboard/plugins/text:all_files", "//tensorflow/tensorboard/scripts:all_files", + "//tensorflow/tools/api/golden:all_files", + "//tensorflow/tools/api/lib:all_files", + "//tensorflow/tools/api/tests:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", "//tensorflow/tools/dist_test/server:all_files", @@ -247,6 +432,7 @@ filegroup( "//tensorflow/tools/test:all_files", "//tensorflow/tools/tfprof:all_files", "//tensorflow/tools/tfprof/internal:all_files", + "//tensorflow/tools/tfprof/internal/advisor:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", "//third_party/sycl:all_files", @@ -269,23 +455,35 @@ filegroup( ), ) +filegroup( + name = "docs_src", + data = glob(["docs_src/**/*.md"]), +) + # ------------------------------------------- # New rules should be added above this target. # ------------------------------------------- cc_binary( name = "libtensorflow.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "//tensorflow/c:exported_symbols.lds", + ], + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//conditions:default": [ + "-z defs", + "-s", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "//tensorflow/c:version_script.lds", + ], + }), linkshared = 1, deps = [ "//tensorflow/c:c_api", - "//tensorflow/core:tensorflow", - ], -) - -cc_binary( - name = "libtensorflow_c.so", - linkshared = 1, - deps = [ - "//tensorflow/c:c_api", + "//tensorflow/c:exported_symbols.lds", + "//tensorflow/c:version_script.lds", "//tensorflow/core:tensorflow", ], ) @@ -296,6 +494,8 @@ cc_binary( deps = [ "//tensorflow/c:c_api", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:scope", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 92d390a9764..083634bd796 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -24,16 +24,19 @@ from __future__ import print_function from tensorflow.python import * # pylint: enable=wildcard-import -# Lazily import the `tf.contrib` module. 
This avoids loading all of the -# dependencies of `tf.contrib` at `import tensorflow` time. -class _LazyContribLoader(object): +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader - def __getattr__(self, item): - global contrib - # Replace the lazy loader with the imported module itself. - import importlib # pylint: disable=g-import-not-at-top - contrib = importlib.import_module('tensorflow.contrib') - return getattr(contrib, item) +del absolute_import +del division +del print_function - -contrib = _LazyContribLoader() +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index e0a4272ee22..3ab4e8efcdb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -26,6 +26,22 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +tf_cuda_library( + name = "c_api_internal", + srcs = ["c_api.h"], + hdrs = ["c_api_internal.h"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + }), +) + tf_cuda_library( name = "c_api", srcs = ["c_api.cc"], @@ -34,10 +50,16 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ + ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":c_api_internal", "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc:gradients", + "//tensorflow/cc:ops", + "//tensorflow/cc:grad_ops", + "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -45,6 +67,14 @@ tf_cuda_library( }), ) +exports_files( + [ + "version_script.lds", + "exported_symbols.lds", + ], + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "tf_status_helper", srcs = ["tf_status_helper.cc"], @@ -89,20 +119,22 @@ tf_cc_test( # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:math", - "//third_party/eigen3", ], ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a51a3ca69e9..77faa475ed4 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -21,8 +21,12 @@ limitations under the License. 
#include #ifndef __ANDROID__ +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/saved_model/loader.h" #endif +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" @@ -52,12 +56,11 @@ limitations under the License. using tensorflow::error::Code; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; +using tensorflow::strings::StrCat; using tensorflow::AllocationDescription; using tensorflow::DataType; -using tensorflow::Env; using tensorflow::Graph; using tensorflow::GraphDef; -using tensorflow::mutex; using tensorflow::mutex_lock; using tensorflow::NameRangeMap; using tensorflow::NameRangesForNode; @@ -68,11 +71,9 @@ using tensorflow::NodeBuilder; using tensorflow::OpDef; using tensorflow::OpRegistry; using tensorflow::PartialTensorShape; -using tensorflow::Reset; using tensorflow::RunMetadata; using tensorflow::RunOptions; using tensorflow::Session; -using tensorflow::SessionOptions; using tensorflow::Status; using tensorflow::Tensor; using tensorflow::TensorBuffer; @@ -92,9 +93,6 @@ size_t TF_DataTypeSize(TF_DataType dt) { } // -------------------------------------------------------------------------- -struct TF_Status { - Status status; -}; TF_Status* TF_NewStatus() { return new TF_Status; } @@ -134,6 +132,9 @@ class TF_ManagedBuffer : public TensorBuffer { proto->set_requested_bytes(rb); proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } }; void* allocate_tensor(const char* operation, size_t len) { @@ -175,12 +176,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } // namespace -struct TF_Tensor { - TF_DataType dtype; - TensorShape shape; - TensorBuffer* buffer; -}; - TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, int num_dims, size_t len) { void* data = allocate_tensor("TF_AllocateTensor", len); @@ -216,6 +211,18 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, return new TF_Tensor{dtype, TensorShape(dimvec), buf}; } +TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { + // It is safe to move the Tensor if and only if we own the unique reference to + // it. In that case, we might as well not delete and reallocate, but a future + // implementation might need to do so. 
+ if (tensor->buffer->RefCountIsOne() && + tensor->buffer->root_buffer()->RefCountIsOne() && + tensor->buffer->OwnsMemory()) { + return tensor; + } + return nullptr; +} + void TF_DeleteTensor(TF_Tensor* t) { t->buffer->Unref(); delete t; @@ -273,9 +280,6 @@ size_t TF_StringEncodedSize(size_t len) { } // -------------------------------------------------------------------------- -struct TF_SessionOptions { - SessionOptions options; -}; TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } @@ -316,9 +320,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) { TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } // -------------------------------------------------------------------------- -struct TF_DeprecatedSession { - Session* session; -}; TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { @@ -328,7 +329,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, return new TF_DeprecatedSession({session}); } else { DCHECK_EQ(nullptr, session); - return NULL; + return nullptr; } } @@ -502,7 +503,7 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, TF_Status* status) { status->status = Status::OK(); for (int i = 0; i < noutputs; ++i) { - c_outputs[i] = NULL; + c_outputs[i] = nullptr; } } @@ -542,9 +543,8 @@ static void TF_Run_Helper( if (handle == nullptr) { RunOptions run_options_proto; - if (run_options != nullptr && - !run_options_proto.ParseFromArray(run_options->data, - run_options->length)) { + if (run_options != nullptr && !run_options_proto.ParseFromArray( + run_options->data, run_options->length)) { status->status = InvalidArgument("Unparseable RunOptions proto"); return; } @@ -651,6 +651,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; } else { + *handle = nullptr; status->status = result; } } @@ -682,11 +683,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, c_outputs, target_oper_names, nullptr, status); } -struct TF_Library { - void* lib_handle; - TF_Buffer op_list; -}; - TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Library* lib_handle = new TF_Library; status->status = tensorflow::LoadLibrary( @@ -714,68 +710,58 @@ TF_Buffer* TF_GetAllOpList() { *(op_list.add_op()) = op; } TF_Buffer* ret = TF_NewBuffer(); - MessageToBuffer(op_list, ret); + TF_CHECK_OK(MessageToBuffer(op_list, ret)); return ret; } +// -------------------------------------------------------------------------- +// ListDevices & SessionListDevices API + +void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; } + +TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, + TF_Status* status) { + TF_DeviceList* response = new TF_DeviceList; + status->status = session->session->ListDevices(&response->response); + return response; +} + +int TF_DeviceListCount(const TF_DeviceList* list) { + return list->response.size(); +} + +#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \ + return_type method_name(const TF_DeviceList* list, const int index, \ + TF_Status* status) { \ + if (list == nullptr) { \ + status->status = InvalidArgument("list is null!"); \ + return err_val; \ + } 
\
+    if (index < 0 || index >= list->response.size()) {          \
+      status->status = InvalidArgument("index out of bounds");  \
+      return err_val;                                           \
+    }                                                           \
+    return list->response[index].accessor;                      \
+  }
+
+TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
+TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
+                     nullptr);
+TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
+
+#undef TF_DEVICELIST_METHOD
+
 }  // end extern "C"

 // --------------------------------------------------------------------------
 // New Graph and Session API

-// Structures -----------------------------------------------------------------
-
-extern "C" {
-
-struct TF_Graph {
-  TF_Graph()
-      : graph(OpRegistry::Global()),
-        refiner(graph.op_registry()),
-        num_sessions(0),
-        delete_requested(false) {}
-  mutex mu;
-  Graph graph GUARDED_BY(mu);
-
-  // Runs shape inference.
-  tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
-
-  // Maps from name of an operation to the Node* in 'graph'.
-  std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);
-
-  // TF_Graph may only / must be deleted when
-  //   num_sessions == 0 && delete_requested == true
-
-  // num_sessions incremented by TF_NewSession, and decremented by
-  // TF_DeleteSession.
-  int num_sessions GUARDED_BY(mu);
-  bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph
-};
-
-struct TF_OperationDescription {
-  TF_OperationDescription(TF_Graph* g, const char* op_type,
-                          const char* node_name)
-      : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
-
-  NodeBuilder node_builder;
-  TF_Graph* graph;
-  std::vector<tensorflow::string> colocation_constraints;
-};
-
-struct TF_Operation {
-  Node node;
-};
-
-struct TF_Session {
-  TF_Session(Session* s, TF_Graph* g)
-      : session(s), graph(g), last_num_graph_nodes(0) {}
-  Session* session;
-  TF_Graph* graph;
-  mutex mu;
-  int last_num_graph_nodes;
-};
-
-}  // end extern "C"
-
 // Helper functions -----------------------------------------------------------

 namespace {

@@ -785,15 +771,13 @@ TF_Operation* ToOperation(Node* node) {
 }

 tensorflow::string OutputName(const TF_Output& output) {
-  return tensorflow::strings::StrCat(output.oper->node.name(), ":",
-                                     output.index);
+  return StrCat(output.oper->node.name(), ":", output.index);
 }

 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                           const char* attr_name,
                                           TF_Status* status) {
-  const tensorflow::AttrValue* attr =
-      tensorflow::AttrSlice(oper->node.def()).Find(attr_name);
+  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
   if (attr == nullptr) {
     status->status =
         InvalidArgument("Operation has no attr named '", attr_name, "'.");
@@ -821,6 +805,7 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
   }

   std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
+  dim_vec.reserve(num_dims);
   for (int i = 0; i < num_dims; ++i) {
     dim_vec.push_back(ic->MakeDim(dims[i]));
   }
@@ -899,10 +884,17 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,

 extern "C" {

+static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
+                                                      const char* op_type,
+                                                      const char* oper_name)
+    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
+  return new TF_OperationDescription(graph, op_type, oper_name);
+}
+
 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
                                          const char* oper_name) {
   mutex_lock l(graph->mu);
-  return new TF_OperationDescription(graph, op_type, oper_name);
+  return TF_NewOperationLocked(graph, op_type, oper_name);
 }

 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
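Editorial note: the device-listing entry points above are thin wrappers over Session::ListDevices. For illustration, a caller might enumerate devices roughly as follows -- a minimal sketch assuming an already-created `TF_Session* sess`, with per-accessor status checks abbreviated:

```c
#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Sketch: enumerate the devices visible to an existing session. */
void print_devices(TF_Session* sess) {
  TF_Status* s = TF_NewStatus();
  TF_DeviceList* devices = TF_SessionListDevices(sess, s);
  if (TF_GetCode(s) == TF_OK) {
    for (int i = 0; i < TF_DeviceListCount(devices); ++i) {
      /* Each accessor validates the index and reports through `s`. */
      const char* name = TF_DeviceListName(devices, i, s);
      const char* type = TF_DeviceListType(devices, i, s);
      int64_t bytes = TF_DeviceListMemoryBytes(devices, i, s);
      printf("%s (%s): %lld bytes\n", name, type, (long long)bytes);
    }
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteStatus(s);
}
```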
@@ -928,8 +920,8 @@ void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { } void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { - desc->colocation_constraints.emplace_back(tensorflow::strings::StrCat( - tensorflow::kColocationGroupPrefix, op->node.name())); + desc->colocation_constraints.emplace_back( + StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); } void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, @@ -1131,10 +1123,10 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, } } -TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, - TF_Status* status) { +static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, + TF_Status* status) + EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { Node* ret = nullptr; - mutex_lock l(desc->graph->mu); if (desc->graph->name_map.count(desc->node_builder.node_name())) { status->status = InvalidArgument("Duplicate node name in graph: '", @@ -1148,14 +1140,14 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, if (status->status.ok()) { // Run shape inference function for newly added node. - // - // TODO(b/28152992): Enable returning the result of this - // code-path once we have converted all python shape functions - // to call their C++ versions. - desc->graph->refiner.AddNode(ret); - + status->status = desc->graph->refiner.AddNode(ret); + } + if (status->status.ok()) { // Add the node to the name-to-node mapping. desc->graph->name_map[ret->name()] = ret; + } else if (ret != nullptr) { + desc->graph->graph.RemoveNode(ret); + ret = nullptr; } } @@ -1164,6 +1156,12 @@ TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, return ToOperation(ret); } +TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, + TF_Status* status) { + mutex_lock l(desc->graph->mu); + return TF_FinishOperationLocked(desc, status); +} + // TF_Operation functions // ---------------------------------------------------------- @@ -1176,7 +1174,7 @@ const char* TF_OperationOpType(TF_Operation* oper) { } const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.def().device().c_str(); + return oper->node.requested_device().c_str(); } int TF_OperationNumOutputs(TF_Operation* oper) { @@ -1191,8 +1189,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) { int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - nullptr, &name_ranges); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1213,8 +1211,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) { int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - &name_ranges, nullptr); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1452,26 +1450,27 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, } } -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, 
\
-            TF_Status* status) {                                             \
-    cpp_type v;                                                              \
-    status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \
-    *value = static_cast<c_type>(v);                                         \
-  }                                                                          \
-  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
-                  int max_values, TF_Status* status) {                       \
-    const auto* attr = GetAttrValue(oper, attr_name, status);                \
-    if (!status->status.ok()) return;                                        \
-    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
-      status->status =                                                       \
-          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
-      return;                                                                \
-    }                                                                        \
-    const auto len = std::min(max_values, attr->list().list_field##_size()); \
-    for (int i = 0; i < len; ++i) {                                          \
-      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
-    }                                                                        \
+#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
+  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
+            TF_Status* status) {                                             \
+    cpp_type v;                                                              \
+    status->status =                                                         \
+        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
+    *value = static_cast<c_type>(v);                                         \
+  }                                                                          \
+  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
+                  int max_values, TF_Status* status) {                       \
+    const auto* attr = GetAttrValue(oper, attr_name, status);                \
+    if (!status->status.ok()) return;                                        \
+    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
+      status->status =                                                       \
+          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
+      return;                                                                \
+    }                                                                        \
+    const auto len = std::min(max_values, attr->list().list_field##_size()); \
+    for (int i = 0; i < len; ++i) {                                          \
+      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
+    }                                                                        \
   }
 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
@@ -1482,7 +1481,8 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
                               int64_t* value, int num_dims, TF_Status* status) {
   PartialTensorShape shape;
-  status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape);
+  status->status =
+      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
   if (!status->status.ok()) return;
   auto len = std::min(shape.dims(), num_dims);
   for (int i = 0; i < len; ++i) {
@@ -1496,7 +1496,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
                                   int storage_size, TF_Status* status) {
   std::vector<PartialTensorShape> shapes;
   status->status =
-      tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes);
+      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
   if (!status->status.ok()) return;
   auto len = std::min(static_cast<int>(shapes.size()), max_values);
   int64_t* p = storage;
@@ -1563,7 +1563,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
                                TF_Tensor** value, TF_Status* status) {
   *value = nullptr;
   Tensor t;
-  status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t);
+  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
   if (!status->status.ok()) return;
   *value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
                          tensorflow::TensorCApi::Buffer(t)};
@@ -1574,7 +1574,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
                                    TF_Tensor** values, int max_values,
                                    TF_Status* status) {
   std::vector<Tensor> ts;
-  status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts);
+  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
   if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i <
len; ++i) { @@ -1653,10 +1653,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, status->status = MessageToBuffer(def, output_graph_def); } -struct TF_ImportGraphDefOptions { - tensorflow::ImportGraphDefOptions opts; -}; - TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } @@ -1682,6 +1678,12 @@ void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst); } +void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { + opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = + TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); +} + extern void TF_ImportGraphDefOptionsAddControlDependency( TF_ImportGraphDefOptions* opts, TF_Operation* oper) { opts->opts.control_dependencies.push_back(oper->node.name()); @@ -1750,6 +1752,398 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, status); } +// While loop functions ------------------------------------------------------- + +namespace { +bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, + TF_Output* input, TF_Status* status) { + TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name); + TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input)); + // TODO(skyewm): set placeholder shape + TF_Operation* oper = TF_FinishOperation(desc, status); + if (!status->status.ok()) return false; + *input = {oper, 0}; + return true; +} + +bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name, + const TF_Output& input, TF_Output* enter, TF_Status* status) + EXCLUSIVE_LOCKS_REQUIRED(g->mu) { + TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name); + TF_AddInput(desc, input); + TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name)); + TF_Operation* oper = TF_FinishOperationLocked(desc, status); + if (!status->status.ok()) return false; + *enter = {oper, 0}; + return true; +} + +bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input, + const char* backedge_name, int backedge_index, + TF_Output* merge, TF_Status* status) + EXCLUSIVE_LOCKS_REQUIRED(g->mu) { + TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name); + + // The merge nodes accept the while loop's back edges as an input. Use the + // underlying NodeBuilder API directly to create an input to the + // not-yet-created back edge. 
+  std::vector<NodeBuilder::NodeOut> input_list;
+  input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index));
+  // All merge inputs must have the same type.
+  DataType type = input.oper->node.output_type(input.index);
+  input_list.push_back(
+      NodeBuilder::NodeOut(backedge_name, backedge_index, type));
+
+  desc->node_builder.Input(input_list);
+
+  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
+  if (!status->status.ok()) return false;
+  *merge = {oper, 0};
+  return true;
+}
+
+bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input,
+                  const TF_Output& predicate, TF_Output* switch_true,
+                  TF_Output* switch_false, TF_Status* status)
+    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
+  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name);
+  TF_AddInput(desc, input);
+  TF_AddInput(desc, predicate);
+  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
+  if (!status->status.ok()) return false;
+  *switch_false = {oper, 0};
+  *switch_true = {oper, 1};
+  return true;
+}
+
+bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input,
+                TF_Output* next, TF_Status* status)
+    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
+  TF_OperationDescription* desc =
+      TF_NewOperationLocked(g, "NextIteration", name);
+  TF_AddInput(desc, input);
+  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
+  if (!status->status.ok()) return false;
+  *next = {oper, 0};
+  return true;
+}
+
+bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input,
+                TF_Output* exit, TF_Status* status)
+    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
+  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name);
+  TF_AddInput(desc, input);
+  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
+  if (!status->status.ok()) return false;
+  *exit = {oper, 0};
+  return true;
+}
+
+class ScopedImportGraphDefOptions {
+ public:
+  ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); }
+  ~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); }
+
+  TF_ImportGraphDefOptions* get() const { return opts_; }
+
+ private:
+  TF_ImportGraphDefOptions* opts_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions);
+};
+
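Editorial note: ScopedImportGraphDefOptions is just an RAII wrapper over the public import-options API, which CopyGraph below also drives. For reference, the same options are usable directly from C; a sketch in which the node name "input" and the prefix "imported" are placeholders:

```c
#include "tensorflow/c/c_api.h"

/* Sketch: import a serialized GraphDef into `graph`, prefixing imported
   node names and splicing `replacement` in wherever the imported graph
   consumed the output "input:0". */
void import_with_remap(TF_Graph* graph, const TF_Buffer* graph_def,
                       TF_Output replacement, TF_Status* s) {
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_ImportGraphDefOptionsAddInputMapping(opts, "input", 0, replacement);
  TF_GraphImportGraphDef(graph, graph_def, opts, s);
  TF_DeleteImportGraphDefOptions(opts);
}
```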
+// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
+// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.
+// `prefix` will be prepended to copied node names. `return_nodes` are nodes
+// in `src_graph`, and the new corresponding nodes in `dst_graph` will be
+// returned. `return_nodes` should be preallocated to size `nreturn_nodes`.
+bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
+               const TF_Output* src_inputs,
+               const std::vector<TF_Output>& dst_inputs, const char* prefix,
+               const TF_Output* nodes_to_return, int nreturn_nodes,
+               TF_Output* return_nodes, TF_Status* s)
+    EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) {
+  GraphDef gdef;
+  src_graph->graph.ToGraphDef(&gdef);
+
+  ScopedImportGraphDefOptions opts;
+  TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix);
+
+  for (int i = 0; i < dst_inputs.size(); ++i) {
+    TensorId src = ToTensorId(src_inputs[i]);
+    TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
+                                            src.second, dst_inputs[i]);
+  }
+  // We use the pivot node to control constants in `src_graph`.
+  TF_Operation* pivot = dst_inputs[0].oper;
+  TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
+
+  for (int i = 0; i < nreturn_nodes; ++i) {
+    TF_ImportGraphDefOptionsAddReturnOutput(
+        opts.get(), nodes_to_return[i].oper->node.name().c_str(),
+        nodes_to_return[i].index);
+  }
+
+  GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes,
+                            nreturn_nodes, s);
+  if (TF_GetCode(s) != TF_OK) return false;
+  return true;
+}
+
+bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
+  if (params.cond_graph == nullptr || params.body_graph == nullptr ||
+      params.cond_graph->parent == nullptr ||
+      params.cond_graph->parent != params.body_graph->parent ||
+      params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
+      params.ninputs <= 0 || params.cond_inputs == nullptr ||
+      params.body_inputs == nullptr || params.body_outputs == nullptr) {
+    s->status = InvalidArgument(
+        "TF_WhileParams must be created by successful TF_NewWhile() call");
+    return false;
+  }
+  return true;
+}
+
+bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
+  if (params.cond_output.oper == nullptr) {
+    s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
+    return false;
+  }
+  for (int i = 0; i < params.ninputs; ++i) {
+    if (params.body_outputs[i].oper == nullptr) {
+      s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
+                                  "field isn't set");
+      return false;
+    }
+  }
+  if (params.name == nullptr) {
+    s->status = InvalidArgument("TF_WhileParams `name` field is null");
+    return false;
+  }
+  return true;
+}
+
+void FreeWhileResources(const TF_WhileParams* params) {
+  TF_DeleteGraph(params->cond_graph);
+  TF_DeleteGraph(params->body_graph);
+  delete[] params->cond_inputs;
+  delete[] params->body_inputs;
+  delete[] params->body_outputs;
+}
+
+TF_WhileParams EmptyWhileParams() {
+  return {0,       nullptr, nullptr, {nullptr, 0},
+          nullptr, nullptr, nullptr, nullptr};
+}
+
+}  // namespace
+
+TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
+                           TF_Status* status) {
+  if (ninputs == 0) {
+    status->status =
+        InvalidArgument("TF_NewWhile() must be passed at least one input");
+    return EmptyWhileParams();
+  }
+
+  TF_Graph* cond_graph = TF_NewGraph();
+  TF_Graph* body_graph = TF_NewGraph();
+  cond_graph->parent = g;
+  cond_graph->parent_inputs = inputs;
+  body_graph->parent = g;
+  body_graph->parent_inputs = inputs;
+
+  TF_Output* cond_inputs = new TF_Output[ninputs];
+  TF_Output cond_output = {nullptr, -1};
+  TF_Output* body_inputs = new TF_Output[ninputs];
+  TF_Output* body_outputs = new TF_Output[ninputs];
+  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
+  const char* name = nullptr;
+
+  for (int i = 0; i < ninputs; ++i) {
+    // TODO(skyewm): prefix names with underscore (requires some plumbing)
+    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
+                     &cond_inputs[i], status)) {
+      break;
+    }
+    if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
+                     &body_inputs[i], status)) {
+      break;
+    }
+  }
+
+  TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
+                           body_graph, body_inputs, body_outputs, name};
+
+  if (!status->status.ok()) {
+    FreeWhileResources(&params);
+    return EmptyWhileParams();
+  }
+  return params;
+}
+
+namespace {
+
+// TODO(skyewm): make nodes in while loop unfetchable like in Python version
+void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
+                          TF_Output* outputs) {
+  if (!ValidateInputWhileParams(*params, status)) return;
+
+  TF_Graph* parent = params->cond_graph->parent;
+  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
+  int n = params->ninputs;
+
+  mutex_lock l(parent->mu);
+
+  // Create Enter nodes
+  std::vector<TF_Output> enter_nodes(n);
+  for (int i = 0; i < n; ++i) {
+    if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(),
+                     params->name, parent_inputs[i], &enter_nodes[i],
+                     status)) {
+      return;
+    }
+  }
+
+  // Create Merge nodes
+  std::vector<TF_Output> merge_nodes(n);
+  for (int i = 0; i < n; ++i) {
+    if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(),
+                     enter_nodes[i], StrCat(params->name, "/next", i).c_str(),
+                     0, &merge_nodes[i], status)) {
+      return;
+    }
+  }
+
+  // Copy cond_graph to parent and replace input placeholders with merge node
+  // outputs, and get handle to new cond output
+  tensorflow::string cond_prefix = StrCat(params->name, "/cond");
+  TF_Output cond_output;
+  if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes,
+                 cond_prefix.c_str(), &params->cond_output, 1, &cond_output,
+                 status)) {
+    return;
+  }
+
+  // Create Switch nodes
+  std::vector<TF_Output> switch_trues(n);
+  std::vector<TF_Output> switch_falses(n);
+  for (int i = 0; i < n; ++i) {
+    if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(),
+                      merge_nodes[i], cond_output, &switch_trues[i],
+                      &switch_falses[i], status)) {
+      return;
+    }
+  }
+
+  // Copy body_graph to parent, replace input placeholders with switch node
+  // true outputs, and get handles to new body outputs
+  tensorflow::string body_prefix = StrCat(params->name, "/body");
+  std::vector<TF_Output> body_outputs(n);
+  if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues,
+                 body_prefix.c_str(), params->body_outputs, n,
+                 body_outputs.data(), status)) {
+    return;
+  }
+
+  // Create Next nodes
+  std::vector<TF_Output> next_nodes(n);
+  for (int i = 0; i < n; ++i) {
+    if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(),
+                    body_outputs[i], &next_nodes[i], status)) {
+      return;
+    }
+  }
+
+  // Create Exit nodes (which are the outputs of the while loop)
+  for (int i = 0; i < n; ++i) {
+    if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(),
+                    switch_falses[i], &outputs[i], status)) {
+      return;
+    }
+  }
+}
+
+}  // namespace
+
+void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
+                    TF_Output* outputs) {
+  // If it appears the caller created or modified `params`, don't free
+  // resources.
+  if (!ValidateConstWhileParams(*params, status)) return;
+  TF_FinishWhileHelper(params, status, outputs);
+  FreeWhileResources(params);
+}
+
+void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
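Editorial note: from the caller's side the lifecycle is: TF_NewWhile hands back two sub-graphs, the caller populates the condition, the body, and a name, and TF_FinishWhile (or TF_AbortWhile on the error path) consumes the params. A rough sketch, where `build_cond` and `build_body` are hypothetical user helpers that add ops to the given graph and return one output:

```c
#include "tensorflow/c/c_api.h"

/* Hypothetical user helpers, not part of the C API. */
TF_Output build_cond(TF_Graph* g, TF_Output in, TF_Status* s);
TF_Output build_body(TF_Graph* g, TF_Output in, TF_Status* s);

void add_loop(TF_Graph* g, TF_Output init, TF_Output* result, TF_Status* s) {
  TF_WhileParams params = TF_NewWhile(g, &init, 1, s);
  if (TF_GetCode(s) != TF_OK) return;

  /* Build the predicate in cond_graph and the step in body_graph. */
  params.cond_output = build_cond(params.cond_graph, params.cond_inputs[0], s);
  params.body_outputs[0] = build_body(params.body_graph,
                                      params.body_inputs[0], s);
  params.name = "while_loop";

  if (TF_GetCode(s) != TF_OK) {
    TF_AbortWhile(&params);  /* Frees the cond/body resources. */
    return;
  }
  /* Stitches Enter/Merge/Switch/NextIteration/Exit into `g`; `result`
     receives the Exit output and `params` is consumed either way. */
  TF_FinishWhile(&params, s, result);
}
```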
+
+#ifndef __ANDROID__
+namespace {
+
+void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
+                          std::vector<tensorflow::Output>* outputs) {
+  outputs->resize(n);
+  for (int i = 0; i < n; i++) {
+    const TF_Output& tf_output = tf_outputs[i];
+    (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
+  }
+}
+
+void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
+                          TF_Output* tf_outputs) {
+  for (int i = 0; i < outputs.size(); i++) {
+    tf_outputs[i].oper = ToOperation(outputs[i].node());
+    tf_outputs[i].index = outputs[i].index();
+  }
+}
+
+}  // namespace
+#endif  // __ANDROID__
+
+void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
+                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
+#ifdef __ANDROID__
+  status->status = tensorflow::errors::Unimplemented(
+      "Adding gradients is not supported in Android. File a bug at "
+      "https://github.com/tensorflow/tensorflow/issues if this feature is "
+      "important to you");
+#else
+  std::vector<tensorflow::Output> y_arg;
+  std::vector<tensorflow::Output> x_arg;
+  std::vector<tensorflow::Output> dy_arg;
+  OutputsFromTFOutputs(y, ny, status, &y_arg);
+  OutputsFromTFOutputs(x, nx, status, &x_arg);
+
+  {
+    // We need to hold on to the lock while we have a scope that uses TF_Graph.
+    mutex_lock graph_lock(g->mu);
+
+    const int max_node_id_before = g->graph.num_node_ids();
+
+    tensorflow::Scope scope =
+        NewInternalScope(&g->graph, &status->status, &g->refiner);
+
+    if (dx != nullptr) {
+      std::vector<tensorflow::Output> dx_arg;
+      OutputsFromTFOutputs(dx, ny, status, &dx_arg);
+      status->status =
+          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
+    } else {
+      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
+    }
+
+    // Update g->name_map with the name_map from the scope, which will contain
+    // the new gradient ops.
+    for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
+      Node* n = g->graph.FindNodeId(i);
+      if (n == nullptr) continue;
+      g->name_map[n->name()] = n;
+    }
+  }
+
+  // Unpack the results from dy_arg.
+  TFOutputsFromOutputs(dy_arg, dy);
+#endif  // __ANDROID__
+}
+
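Editorial note: TF_AddGradients mirrors AddSymbolicGradients, and passing `dx == NULL` selects the overload that seeds the gradients itself (with ones, per the cc/gradients behavior). A minimal single-variable sketch, error handling elided:

```c
#include <stddef.h>
#include "tensorflow/c/c_api.h"

/* Sketch: request the symbolic gradient dy/dx for outputs that already
   live in `g`. `dy` is only meaningful when `s` reports TF_OK. */
TF_Output gradient_of(TF_Graph* g, TF_Output y, TF_Output x, TF_Status* s) {
  TF_Output dy;
  TF_AddGradients(g, /*y=*/&y, /*ny=*/1, /*x=*/&x, /*nx=*/1,
                  /*dx=*/NULL, s, &dy);
  return dy;
}
```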
 // TF_Session functions ----------------------------------------------

 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
@@ -1764,15 +2158,23 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
     return new TF_Session(session, graph);
   } else {
     DCHECK_EQ(nullptr, session);
-    return NULL;
+    return nullptr;
   }
 }

-#ifndef __ANDROID__
 TF_Session* TF_LoadSessionFromSavedModel(
     const TF_SessionOptions* session_options, const TF_Buffer* run_options,
     const char* export_dir, const char* const* tags, int tags_len,
     TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
+// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring
+// that the tensorflow/cc/saved_model:loader build target is Android friendly.
+#ifdef __ANDROID__
+  status->status = tensorflow::errors::Unimplemented(
+      "Loading a SavedModel is not supported in Android. File a bug at "
+      "https://github.com/tensorflow/tensorflow/issues if this feature is "
+      "important to you");
+  return nullptr;
+#else
   mutex_lock l(graph->mu);

   if (!graph->name_map.empty()) {
@@ -1781,9 +2183,8 @@ TF_Session* TF_LoadSessionFromSavedModel(
   }

   RunOptions run_options_proto;
-  if (run_options != nullptr &&
-      !run_options_proto.ParseFromArray(run_options->data,
-                                        run_options->length)) {
+  if (run_options != nullptr && !run_options_proto.ParseFromArray(
+                                    run_options->data, run_options->length)) {
     status->status = InvalidArgument("Unparseable RunOptions proto");
     return nullptr;
   }
@@ -1821,8 +2222,8 @@ TF_Session* TF_LoadSessionFromSavedModel(
   graph->num_sessions += 1;
   session->last_num_graph_nodes = graph->graph.num_node_ids();
   return session;
-}
 #endif  // __ANDROID__
+}

 void TF_CloseSession(TF_Session* s, TF_Status* status) {
   status->status = s->session->Close();
@@ -1853,7 +2254,7 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
     const auto num_nodes = graph.num_node_ids();
     if (session->last_num_graph_nodes < num_nodes) {
       GraphDef graph_def;
-      graph_def.mutable_versions()->CopyFrom(graph.versions());
+      *graph_def.mutable_versions() = graph.versions();
       // Fill graph_def with nodes with ids in the range
       // [session->last_num_graph_nodes, num_nodes), that is the nodes
       // added since the last TF_SessionRun() call.
@@ -1954,6 +2355,11 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
   }
 }

+void TF_DeletePRunHandle(const char* handle) {
+  delete[] handle;
+  // TODO(suharshs): Free up any resources held by the partial run state.
+}
+
 void TF_SessionPRun(TF_Session* session, const char* handle,
                     const TF_Output* inputs, TF_Tensor* const* input_values,
                     int ninputs, const TF_Output* outputs,
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 8d0f398d4a5..15139a47acf 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -64,6 +64,25 @@ limitations under the License.
 // and the API just provides high level controls over the number of
 // devices of each type.

+// Macro to control visibility of exported symbols in the shared library (.so,
+// .dylib, .dll).
+// This duplicates the TF_EXPORT macro definition in
+// tensorflow/core/platform/macros.h in order to keep this .h file independent
+// of any other includes.
+#ifdef SWIG
+#define TF_CAPI_EXPORT
+#else
+#if defined(COMPILER_MSVC)
+#ifdef TF_COMPILE_LIBRARY
+#define TF_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TF_CAPI_EXPORT __declspec(dllimport)
+#endif  // TF_COMPILE_LIBRARY
+#else
+#define TF_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // COMPILER_MSVC
+#endif  // SWIG
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -71,12 +90,12 @@ extern "C" {
 // --------------------------------------------------------------------------
 // TF_Version returns a string describing version information of the
 // TensorFlow library. TensorFlow uses semantic versioning.
-extern const char* TF_Version();
+TF_CAPI_EXPORT extern const char* TF_Version();

 // --------------------------------------------------------------------------
 // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
 // The enum values here are identical to corresponding values in types.proto.
-typedef enum {
+typedef enum TF_DataType {
   TF_FLOAT = 1,
   TF_DOUBLE = 2,
   TF_INT32 = 3,  // Int32 tensors are always in 'host' memory.
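Editorial note: with the guard now inside the function, non-Android callers restore a bundle in the usual way. A sketch assuming the conventional "serve" export tag, and assuming the run-options and MetaGraphDef arguments accept NULL when the caller does not need them:

```c
#include <stddef.h>
#include "tensorflow/c/c_api.h"

/* Sketch: load a SavedModel exported under the "serve" tag. Returns a
   session backed by `graph`, or NULL with `s` set on failure. */
TF_Session* load_saved_model(const char* export_dir, TF_Graph* graph,
                             TF_Status* s) {
  TF_SessionOptions* opts = TF_NewSessionOptions();
  const char* tags[] = {"serve"};
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opts, /*run_options=*/NULL, export_dir, tags, /*tags_len=*/1, graph,
      /*meta_graph_def=*/NULL, s);
  TF_DeleteSessionOptions(opts);
  return session;
}
```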
@@ -103,12 +122,12 @@ typedef enum { // TF_DataTypeSize returns the sizeof() for the underlying type corresponding // to the given TF_DataType enum value. Returns 0 for variable length types // (eg. TF_STRING) or on failure. -extern size_t TF_DataTypeSize(TF_DataType dt); +TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. -typedef enum { +typedef enum TF_Code { TF_OK = 0, TF_CANCELLED = 1, TF_UNKNOWN = 2, @@ -134,23 +153,24 @@ typedef enum { typedef struct TF_Status TF_Status; // Return a new status object. -extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); // Delete a previously created status object. -extern void TF_DeleteStatus(TF_Status*); +TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); // Record in *s. Any previous information is lost. // A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); -extern void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg); +TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code, + const char* msg); // Return the code record in *s. -extern TF_Code TF_GetCode(const TF_Status* s); +TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s); // Return a pointer to the (null-terminated) error message in *s. The // return value points to memory that is only usable until the next // mutation to *s. Always returns an empty string if TF_GetCode(s) is // TF_OK. -extern const char* TF_Message(const TF_Status* s); +TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s); // -------------------------------------------------------------------------- // TF_Buffer holds a pointer to a block of data and its associated length. @@ -168,14 +188,15 @@ typedef struct TF_Buffer { // Makes a copy of the input and sets an appropriate deallocator. Useful for // passing in read-only, input protobufs. -extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, + size_t proto_len); // Useful for passing *out* a protobuf. -extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); -extern void TF_DeleteBuffer(TF_Buffer*); +TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); -extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); +TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); // -------------------------------------------------------------------------- // TF_Tensor holds a multi-dimensional array of elements of a single data type. @@ -202,11 +223,10 @@ typedef struct TF_Tensor TF_Tensor; // (*deallocator)(data, len, deallocator_arg) // Clients must provide a custom deallocator function so they can pass in // memory managed by something like numpy. -extern TF_Tensor* TF_NewTensor(TF_DataType, const int64_t* dims, int num_dims, - void* data, size_t len, - void (*deallocator)(void* data, size_t len, - void* arg), - void* deallocator_arg); +TF_CAPI_EXPORT extern TF_Tensor* TF_NewTensor( + TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg); // Allocate and return a new Tensor. 
// @@ -217,27 +237,32 @@ extern TF_Tensor* TF_NewTensor(TF_DataType, const int64_t* dims, int num_dims, // // The caller must set the Tensor values by writing them to the pointer returned // by TF_TensorData with length TF_TensorByteSize. -extern TF_Tensor* TF_AllocateTensor(TF_DataType, const int64_t* dims, - int num_dims, size_t len); +TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTensor(TF_DataType, + const int64_t* dims, + int num_dims, size_t len); + +// Deletes `tensor` and returns a new TF_Tensor with the same content if +// possible. Returns nullptr and leaves `tensor` untouched if not. +TF_CAPI_EXPORT extern TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor); // Destroy a tensor. -extern void TF_DeleteTensor(TF_Tensor*); +TF_CAPI_EXPORT extern void TF_DeleteTensor(TF_Tensor*); // Return the type of a tensor element. -extern TF_DataType TF_TensorType(const TF_Tensor*); +TF_CAPI_EXPORT extern TF_DataType TF_TensorType(const TF_Tensor*); // Return the number of dimensions that the tensor has. -extern int TF_NumDims(const TF_Tensor*); +TF_CAPI_EXPORT extern int TF_NumDims(const TF_Tensor*); // Return the length of the tensor in the "dim_index" dimension. // REQUIRES: 0 <= dim_index < TF_NumDims(tensor) -extern int64_t TF_Dim(const TF_Tensor* tensor, int dim_index); +TF_CAPI_EXPORT extern int64_t TF_Dim(const TF_Tensor* tensor, int dim_index); // Return the size of the underlying data in bytes. -extern size_t TF_TensorByteSize(const TF_Tensor*); +TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. -extern void* TF_TensorData(const TF_Tensor*); +TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format @@ -247,8 +272,9 @@ extern void* TF_TensorData(const TF_Tensor*); // // On success returns the size in bytes of the encoded string. // Returns an error into `status` otherwise. -extern size_t TF_StringEncode(const char* src, size_t src_len, char* dst, - size_t dst_len, TF_Status* status); +TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len, + char* dst, size_t dst_len, + TF_Status* status); // Decode a string encoded using TF_StringEncode. // @@ -258,19 +284,20 @@ extern size_t TF_StringEncode(const char* src, size_t src_len, char* dst, // `*dst` and `*dst_len` are undefined and an error is set in `status`. // // Does not read memory more than `src_len` bytes beyond `src`. -extern size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, - size_t* dst_len, TF_Status* status); +TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len, + const char** dst, size_t* dst_len, + TF_Status* status); // Return the size in bytes required to encode a string `len` bytes long into a // TF_STRING tensor. -extern size_t TF_StringEncodedSize(size_t len); +TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); // -------------------------------------------------------------------------- // TF_SessionOptions holds options that can be passed during session creation. typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. 
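Editorial note: the TF_StringEncode / TF_StringDecode / TF_StringEncodedSize trio above round-trips as follows; a minimal sketch that only demonstrates the calls:

```c
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/c_api.h"

/* Sketch: round-trip one string through the TF_STRING encoding. */
void string_round_trip(const char* text, TF_Status* s) {
  size_t len = strlen(text);
  size_t cap = TF_StringEncodedSize(len);
  char* buf = (char*)malloc(cap);
  size_t written = TF_StringEncode(text, len, buf, cap, s);

  const char* decoded = NULL;
  size_t decoded_len = 0;
  if (TF_GetCode(s) == TF_OK) {
    /* `decoded` points into `buf`, so use it before freeing. */
    TF_StringDecode(buf, written, &decoded, &decoded_len, s);
  }
  free(buf);
}
```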
@@ -278,17 +305,19 @@ extern TF_SessionOptions* TF_NewSessionOptions(); // "local" // ip:port // host:port -extern void TF_SetTarget(TF_SessionOptions* options, const char* target); +TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, + const char* target); // Set the config in TF_SessionOptions.options. // config should be a serialized tensorflow.ConfigProto proto. // If config was not parsed successfully as a ConfigProto, record the // error information in *status. -extern void TF_SetConfig(TF_SessionOptions* options, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, + const void* proto, size_t proto_len, + TF_Status* status); // Destroy an options object. -extern void TF_DeleteSessionOptions(TF_SessionOptions*); +TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); // TODO(jeff,sanjay): // - export functions to set Config fields @@ -301,11 +330,11 @@ extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. -extern void TF_DeleteGraph(TF_Graph*); +TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); // Operation being built. The underlying graph must outlive this. typedef struct TF_OperationDescription TF_OperationDescription; @@ -343,9 +372,11 @@ typedef struct TF_Output { // * `output` is not in `graph`. // * An invalid shape is being set (e.g., the shape being set // is incompatible with the existing shape). -extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, - const int64_t* dims, const int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, + TF_Output output, + const int64_t* dims, + const int num_dims, + TF_Status* status); // Returns the number of dimensions of the Tensor referenced by `output` // in `graph`. @@ -354,8 +385,9 @@ extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, // // Returns an error into `status` if: // * `output` is not in `graph`. -extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, + TF_Output output, + TF_Status* status); // Returns the shape of the Tensor referenced by `output` in `graph` // into `dims`. `dims` must be an array large enough to hold `num_dims` @@ -369,20 +401,21 @@ extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, // Returns an error into `status` if: // * `output` is not in `graph`. // * `num_dims` does not match the actual number of dimensions. -extern void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, - int64_t* dims, int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, + TF_Output output, + int64_t* dims, int num_dims, + TF_Status* status); // Operation will only be added to *graph when TF_FinishOperation() is // called (assuming TF_FinishOperation() does not return an error). // *graph must not be deleted until after TF_FinishOperation() is // called. 
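Editorial note: for reference, the describe / configure / finish sequence that the declarations below spell out looks like this in practice. A sketch: the "Placeholder" op with its "dtype" and "shape" attrs is standard, but the concrete dims here are arbitrary:

```c
#include "tensorflow/c/c_api.h"

/* Sketch: describe, configure, and finish one operation. */
TF_Operation* make_placeholder(TF_Graph* graph, const char* name,
                               TF_Status* s) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  const int64_t dims[] = {-1, 128};  /* -1 marks an unknown dimension. */
  TF_SetAttrShape(desc, "shape", dims, 2);
  return TF_FinishOperation(desc, s);  /* Consumes `desc` either way. */
}
```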
-extern TF_OperationDescription* TF_NewOperation(TF_Graph* graph, - const char* op_type, - const char* oper_name); +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( + TF_Graph* graph, const char* op_type, const char* oper_name); // Specify the device for `desc`. Defaults to empty, meaning unconstrained. -extern void TF_SetDevice(TF_OperationDescription* desc, const char* device); +TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, + const char* device); // The calls to TF_AddInput and TF_AddInputList must match (in number, // order, and type) the op declaration. For example, the "Concat" op @@ -405,101 +438,115 @@ extern void TF_SetDevice(TF_OperationDescription* desc, const char* device); // TF_AddInputList(desc, values_inputs, 5); // For inputs that take a single tensor. -extern void TF_AddInput(TF_OperationDescription* desc, TF_Output input); +TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, + TF_Output input); // For inputs that take a list of tensors. // inputs must point to TF_Output[num_inputs]. -extern void TF_AddInputList(TF_OperationDescription* desc, - const TF_Output* inputs, int num_inputs); +TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, + const TF_Output* inputs, + int num_inputs); // Call once per control input to `desc`. -extern void TF_AddControlInput(TF_OperationDescription* desc, - TF_Operation* input); +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, + TF_Operation* input); // Request that `desc` be co-located on the device where `op` // is placed. // // Use of this is discouraged since the implementation of device placement is // subject to change. Primarily intended for internal libraries -extern void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op); +TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, + TF_Operation* op); // Call some TF_SetAttr*() function for every attr that is not // inferred from an input and doesn't have a default value you wish to // keep. // `value` must point to a string of length `length` bytes. -extern void TF_SetAttrString(TF_OperationDescription* desc, - const char* attr_name, const void* value, - size_t length); +TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, + const char* attr_name, + const void* value, size_t length); // `values` and `lengths` each must have lengths `num_values`. // `values[i]` must point to a string of length `lengths[i]` bytes. 
-extern void TF_SetAttrStringList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* values, - const size_t* lengths, int num_values); -extern void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, - int64_t value); -extern void TF_SetAttrIntList(TF_OperationDescription* desc, - const char* attr_name, const int64_t* values, - int num_values); -extern void TF_SetAttrFloat(TF_OperationDescription* desc, - const char* attr_name, float value); -extern void TF_SetAttrFloatList(TF_OperationDescription* desc, - const char* attr_name, const float* values, - int num_values); -extern void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, - unsigned char value); -extern void TF_SetAttrBoolList(TF_OperationDescription* desc, - const char* attr_name, - const unsigned char* values, int num_values); -extern void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, - TF_DataType value); -extern void TF_SetAttrTypeList(TF_OperationDescription* desc, - const char* attr_name, const TF_DataType* values, - int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, + const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, + const char* attr_name, float value); +TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, + const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, + const char* attr_name, + const TF_DataType* values, + int num_values); // Set `num_dims` to -1 to represent "unknown rank". Otherwise, // `dims` points to an array of length `num_dims`. `dims[i]` must be // >= -1, with -1 meaning "unknown dimension". -extern void TF_SetAttrShape(TF_OperationDescription* desc, - const char* attr_name, const int64_t* dims, - int num_dims); +TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* dims, int num_dims); // `dims` and `num_dims` must point to arrays of length `num_shapes`. // Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, // `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` // must be >= -1, with -1 meaning "unknown dimension". -extern void TF_SetAttrShapeList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* const* dims, const int* num_dims, - int num_shapes); +TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* const* dims, + const int* num_dims, + int num_shapes); // `proto` must point to an array of `proto_len` bytes representing a // binary-serialized TensorShapeProto. 
-extern void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( + TF_OperationDescription* desc, const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status); // `protos` and `proto_lens` must point to arrays of length `num_shapes`. // `protos[i]` must point to an array of `proto_lens[i]` bytes // representing a binary-serialized TensorShapeProto. -extern void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* protos, - const size_t* proto_lens, - int num_shapes, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( + TF_OperationDescription* desc, const char* attr_name, + const void* const* protos, const size_t* proto_lens, int num_shapes, + TF_Status* status); -extern void TF_SetAttrTensor(TF_OperationDescription* desc, - const char* attr_name, TF_Tensor* value, - TF_Status* status); -extern void TF_SetAttrTensorList(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* const* values, int num_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* value, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* const* values, + int num_values, + TF_Status* status); // `proto` should point to a sequence of bytes of length `proto_len` // representing a binary serialization of an AttrValue protocol // buffer. -extern void TF_SetAttrValueProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); // If this function succeeds: // * *status is set to an OK value, @@ -511,37 +558,38 @@ extern void TF_SetAttrValueProto(TF_OperationDescription* desc, // * the graph is not modified, // * a null value is returned. // In either case, it deletes `desc`. -extern TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, - TF_Status* status); +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( + TF_OperationDescription* desc, TF_Status* status); // TF_Operation functions. Operations are immutable once created, so // these are all query functions. 
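Editorial note: the query functions declared just below are side-effect free; dumping one node might look like this (a minimal sketch):

```c
#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Sketch: print basic metadata for one operation. */
void describe_op(TF_Operation* oper) {
  printf("%s (%s) device='%s'\n", TF_OperationName(oper),
         TF_OperationOpType(oper), TF_OperationDevice(oper));
  for (int i = 0; i < TF_OperationNumOutputs(oper); ++i) {
    TF_Output out = {oper, i};
    printf("  output %d: dtype=%d, consumers=%d\n", i,
           (int)TF_OperationOutputType(out),
           TF_OperationOutputNumConsumers(out));
  }
}
```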
-extern const char* TF_OperationName(TF_Operation* oper); -extern const char* TF_OperationOpType(TF_Operation* oper); -extern const char* TF_OperationDevice(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); -extern int TF_OperationNumOutputs(TF_Operation* oper); -extern TF_DataType TF_OperationOutputType(TF_Output oper_out); -extern int TF_OperationOutputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); -extern int TF_OperationNumInputs(TF_Operation* oper); -extern TF_DataType TF_OperationInputType(TF_Input oper_in); -extern int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); +TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); // In this code: // TF_Output producer = TF_OperationInput(consumer); // There is an edge from producer.oper's output (given by // producer.index) to consumer.oper's input (given by consumer.index). -extern TF_Output TF_OperationInput(TF_Input oper_in); +TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); // Get the number of current consumers of a specific output of an // operation. Note that this number can change when new operations // are added to the graph. -extern int TF_OperationOutputNumConsumers(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); // Get list of all current consumers of a specific output of an // operation. `consumers` must point to an array of length at least @@ -550,24 +598,24 @@ extern int TF_OperationOutputNumConsumers(TF_Output oper_out); // modification of the graph can increase the number of consumers of // an operation. Returns the number of output consumers (should match // TF_OperationOutputNumConsumers(oper_out)). -extern int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, - int max_consumers); +TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, + TF_Input* consumers, + int max_consumers); // Get the number of control inputs to an operation. -extern int TF_OperationNumControlInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); // Get list of all control inputs to an operation. `control_inputs` must // point to an array of length `max_control_inputs` (ideally set to // TF_OperationNumControlInputs(oper)). Returns the number of control // inputs (should match TF_OperationNumControlInputs(oper)). -extern int TF_OperationGetControlInputs(TF_Operation* oper, - TF_Operation** control_inputs, - int max_control_inputs); +TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( + TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); // Get the number of operations that have `*oper` as a control input. // Note that this number can change when new operations are added to // the graph. 
-extern int TF_OperationNumControlOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); // Get the list of operations that have `*oper` as a control input. // `control_outputs` must point to an array of length at least @@ -576,12 +624,12 @@ extern int TF_OperationNumControlOutputs(TF_Operation* oper); // modification of the graph can increase the number of control // outputs. Returns the number of control outputs (should match // TF_OperationNumControlOutputs(oper)). -extern int TF_OperationGetControlOutputs(TF_Operation* oper, - TF_Operation** control_outputs, - int max_control_outputs); +TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( + TF_Operation* oper, TF_Operation** control_outputs, + int max_control_outputs); // TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum { +typedef enum TF_AttrType { TF_ATTR_STRING = 0, TF_ATTR_INT = 1, TF_ATTR_FLOAT = 2, @@ -625,17 +673,18 @@ typedef struct TF_AttrMetadata { } TF_AttrMetadata; // Returns metadata about the value of the attribute `attr_name` of `oper`. -extern TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, - const char* attr_name, - TF_Status* status); +TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( + TF_Operation* oper, const char* attr_name, TF_Status* status); // Fills in `value` with the value of the attribute `attr_name`. `value` must // point to an array of length at least `max_length` (ideally set to // TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, - void* value, size_t max_length, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, + const char* attr_name, + void* value, + size_t max_length, + TF_Status* status); // Get the list of strings in the value of the attribute `attr_name`. Fills in // `values` and `lengths`, each of which must point to an array of length at @@ -648,64 +697,78 @@ extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, // attr_name). // // Fails if storage_size is too small to hold the requested number of strings. -extern void TF_OperationGetAttrStringList(TF_Operation* oper, - const char* attr_name, void** values, - size_t* lengths, int max_values, - void* storage, size_t storage_size, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( + TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, + int max_values, void* storage, size_t storage_size, TF_Status* status); -extern void TF_OperationGetAttrInt(TF_Operation* oper, const char* attr_name, - int64_t* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, + const char* attr_name, + int64_t* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). 
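One way TF_AttrMetadata is meant to be used, sketched for the string getter above; this assumes a non-list, string-typed attribute, and the caller owns the returned buffer:

```c
#include <stdlib.h>
#include "tensorflow/c/c_api.h"

/* Read a string attribute, sizing the buffer from TF_AttrMetadata.
 * Attribute strings are not NUL-terminated, so one is appended. */
char* GetStringAttr(TF_Operation* oper, const char* attr_name,
                    TF_Status* status) {
  TF_AttrMetadata meta = TF_OperationGetAttrMetadata(oper, attr_name, status);
  if (TF_GetCode(status) != TF_OK || meta.type != TF_ATTR_STRING) return NULL;
  char* buf = (char*)malloc(meta.total_size + 1);
  TF_OperationGetAttrString(oper, attr_name, buf, meta.total_size, status);
  buf[meta.total_size] = '\0';
  return buf;
}
```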
-extern void TF_OperationGetAttrIntList(TF_Operation* oper, - const char* attr_name, int64_t* values, - int max_values, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, + const char* attr_name, + int64_t* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrFloat(TF_Operation* oper, const char* attr_name, - float* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, + const char* attr_name, + float* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrFloatList(TF_Operation* oper, - const char* attr_name, float* values, - int max_values, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, + const char* attr_name, + float* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrBool(TF_Operation* oper, const char* attr_name, - unsigned char* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, + const char* attr_name, + unsigned char* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrBoolList(TF_Operation* oper, - const char* attr_name, - unsigned char* values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, + const char* attr_name, + unsigned char* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrType(TF_Operation* oper, const char* attr_name, - TF_DataType* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, + const char* attr_name, + TF_DataType* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrTypeList(TF_Operation* oper, - const char* attr_name, - TF_DataType* values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, + const char* attr_name, + TF_DataType* values, + int max_values, + TF_Status* status); // Fills in `value` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `num_dims` (ideally set to // TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). -extern void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, - int64_t* value, int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, + const char* attr_name, + int64_t* value, + int num_dims, + TF_Status* status); // Fills in `dims` with the list of shapes in the attribute `attr_name` of // `oper` and `num_dims` with the corresponding number of dimensions. On return, @@ -720,35 +783,32 @@ extern void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, // attr_name). 
// // Fails if storage_size is insufficient to hold the requested shapes. -extern void TF_OperationGetAttrShapeList(TF_Operation* oper, - const char* attr_name, int64_t** dims, - int* num_dims, int num_shapes, - int64_t* storage, int storage_size, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( + TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, + int num_shapes, int64_t* storage, int storage_size, TF_Status* status); // Sets `value` to the binary-serialized TensorShapeProto of the value of // `attr_name` attribute of `oper`'. -extern void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* value, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* value, + TF_Status* status); // Fills in `values` with binary-serialized TensorShapeProto values of the // attribute `attr_name` of `oper`. `values` must point to an array of length at // least `num_values` (ideally set to TF_AttrMetadata.list_size from // TF_OperationGetAttrMetadata(oper, attr_name)). -extern void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, - const char* attr_name, - TF_Buffer** values, - int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( + TF_Operation* oper, const char* attr_name, TF_Buffer** values, + int max_values, TF_Status* status); // Gets the TF_Tensor valued attribute of `attr_name` of `oper`. // // Allocates a new TF_Tensor which the caller is expected to take // ownership of (and can deallocate using TF_DeleteTensor). -extern void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, - TF_Tensor** value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, + const char* attr_name, + TF_Tensor** value, + TF_Status* status); // Fills in `values` with the TF_Tensor values of the attribute `attr_name` of // `oper`. `values` must point to an array of TF_Tensor* of length at least @@ -757,22 +817,22 @@ extern void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, // // The caller takes ownership of all the non-null TF_Tensor* entries in `values` // (which can be deleted using TF_DeleteTensor(values[i])). -extern void TF_OperationGetAttrTensorList(TF_Operation* oper, - const char* attr_name, - TF_Tensor** values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, + const char* attr_name, + TF_Tensor** values, + int max_values, + TF_Status* status); // Sets `output_attr_value` to the binary-serialized AttrValue proto // representation of the value of the `attr_name` attr of `oper`. -extern void TF_OperationGetAttrValueProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* output_attr_value, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); // Returns the operation in the graph with `oper_name`. Returns nullptr if // no operation found. -extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph, - const char* oper_name); +TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( + TF_Graph* graph, const char* oper_name); // Iterate through the operations of a graph. 
To use: // size_t pos = 0; @@ -780,48 +840,60 @@ extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph, // while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { // DoSomethingWithOperation(oper); // } -extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos); +TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, + size_t* pos); // Write out a serialized representation of `graph` (as a GraphDef protocol // message) to `output_graph_def` (allocated by TF_NewBuffer()). +// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. // // May fail on very large graphs in the future. -extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, + TF_Buffer* output_graph_def, + TF_Status* status); // TF_ImportGraphDefOptions holds options that can be passed to // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); -extern void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( + TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will // be imported into `graph`. -extern void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, - const char* prefix); +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( + TF_ImportGraphDefOptions* opts, const char* prefix); // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. -extern void TF_ImportGraphDefOptionsAddInputMapping( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); +// Set any imported nodes with control input `src_name` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references an operation already existing in the graph being imported +// into. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); + // Cause the imported graph to have a control dependency on `oper`. `oper` // should exist in the graph being imported into. -extern void TF_ImportGraphDefOptionsAddControlDependency( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( TF_ImportGraphDefOptions* opts, TF_Operation* oper); // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. -extern void TF_ImportGraphDefOptionsAddReturnOutput( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); // Returns the number of return outputs added via // TF_ImportGraphDefOptionsAddReturnOutput(). 
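A sketch of the export/import path these declarations describe, using TF_GraphImportGraphDef() from just below; error handling is reduced to a single check:

```c
#include "tensorflow/c/c_api.h"

/* Serialize `src` to a GraphDef and import it into `dst` under a prefix,
 * so the imported nodes appear as "copied/...". */
void CopyGraph(TF_Graph* src, TF_Graph* dst, TF_Status* status) {
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(src, graph_def, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
    TF_ImportGraphDefOptionsSetPrefix(opts, "copied");
    TF_GraphImportGraphDef(dst, graph_def, opts, status);
    TF_DeleteImportGraphDefOptions(opts);
  }
  TF_DeleteBuffer(graph_def);
}
```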
-extern int TF_ImportGraphDefOptionsNumReturnOutputs( +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Import the graph serialized in `graph_def` into `graph`. @@ -830,24 +902,103 @@ extern int TF_ImportGraphDefOptionsNumReturnOutputs( // result of TF_ImportGraphDefOptionsNumReturnOutputs()). If // `num_return_outputs` is non-zero, `return_outputs` must be of length // `num_return_outputs`. Otherwise it can be null. -extern void TF_GraphImportGraphDefWithReturnOutputs( +TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, int num_return_outputs, TF_Status* status); // Import the graph serialized in `graph_def` into `graph`. // Convenience function for when no return outputs have been added. -extern void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); // Note: The following function may fail on very large protos in the future. -extern void TF_OperationToNodeDef(TF_Operation* oper, - TF_Buffer* output_node_def, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, + TF_Buffer* output_node_def, + TF_Status* status); -// TODO(andydavis): Function to add gradients to a graph. +typedef struct TF_WhileParams { + // The number of inputs to the while loop, i.e. the number of loop variables. + // This is the size of cond_inputs, body_inputs, and body_outputs. + const int ninputs; + + // The while condition graph. The inputs are the current values of the loop + // variables. The output should be a scalar boolean. + TF_Graph* const cond_graph; + const TF_Output* const cond_inputs; + TF_Output cond_output; + + // The loop body graph. The inputs are the current values of the loop + // variables. The outputs are the updated values of the loop variables. + TF_Graph* const body_graph; + const TF_Output* const body_inputs; + TF_Output* const body_outputs; + + // Unique null-terminated name for this while loop. This is used as a prefix + // for created operations. + const char* name; +} TF_WhileParams; + +// Creates a TF_WhileParams for creating a while loop in `g`. `inputs` are +// outputs that already exist in `g` used as initial values for the loop +// variables. +// +// The returned TF_WhileParams will have all fields initialized except +// `cond_output`, `body_outputs`, and `name`. The `body_outputs` buffer will be +// allocated to size `ninputs`. The caller should build `cond_graph` and +// `body_graph` starting from the inputs, and store the final outputs in +// `cond_output` and `body_outputs`. +// +// If `status` is OK, the caller must call either TF_FinishWhile or +// TF_AbortWhile on the returned TF_WhileParams. If `status` isn't OK, the +// returned TF_WhileParams is not valid, and the caller should not call +// TF_FinishWhile() or TF_AbortWhile(). 
+// +// Missing functionality (TODO): +// - Gradients +// - Reference-type inputs +// - Directly referencing external tensors from the cond/body graphs (this is +//   possible in the Python API) +TF_CAPI_EXPORT extern TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, +                                                 int ninputs, +                                                 TF_Status* status); + +// Builds the while loop specified by `params` and returns the output tensors of +// the while loop in `outputs`. `outputs` should be allocated to size +// `params.ninputs`. +// +// `params` is no longer valid once this returns. +// +// Either this or TF_AbortWhile() must be called after a successful +// TF_NewWhile() call. +TF_CAPI_EXPORT extern void TF_FinishWhile(const TF_WhileParams* params, +                                          TF_Status* status, +                                          TF_Output* outputs); + +// Frees `params`'s resources without building a while loop. `params` is no +// longer valid after this returns. Either this or TF_FinishWhile() must be +// called after a successful TF_NewWhile() call. +TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); + +// Adds operations to compute the partial derivatives of the sum of `y`s w.r.t. `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will seed the gradients with ones of +// the same shapes as `y` (i.e., `OnesLike`). +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// +// WARNING: This function does not yet support all the gradients that Python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add more C++ gradients. +TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, +                                    TF_Output* x, int nx, TF_Output* dx, +                                    TF_Status* status, TF_Output* dy); // TODO(josh11b): Register OpDef, available to all operations added // to this graph. @@ -855,7 +1006,6 @@ extern void TF_OperationToNodeDef(TF_Operation* oper, // The following two may both benefit from a subgraph-definition API // that re-uses most of the graph-definition API. // TODO(andydavis): Add functions to a graph. -// TODO(yuanbyu): Add while loop to graph. // -------------------------------------------------------------------------- // API for driving Graph execution. @@ -867,12 +1017,9 @@ typedef struct TF_Session TF_Session; // *graph must be a valid graph (not deleted or nullptr). This function will // prevent the graph from being deleted until TF_DeleteSession() is called. // Does not take ownership of opts. -extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, -                                 TF_Status* status); - -#ifndef __ANDROID__ -// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that -// the tensorflow/cc/saved_model:loader build target is Android friendly. +TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, +                                                const TF_SessionOptions* opts, +                                                TF_Status* status); // This function creates a new TF_Session (created only on success) using // `session_options`, and then initializes state (restoring tensors and other @@ -888,17 +1035,16 @@ extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, // // If successful, populates `graph` with the contents of the Graph and // `meta_graph_def` with the MetaGraphDef of the loaded model.
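A sketch of the TF_NewWhile()/TF_FinishWhile() protocol documented above, for a single int32 loop variable. `BuildCond` and `BuildBody` are hypothetical helpers that add ordinary operations to `params.cond_graph` and `params.body_graph` and return the resulting outputs:

```c
#include <stddef.h>
#include "tensorflow/c/c_api.h"

TF_Output BuildCond(TF_Graph* g, TF_Output in, TF_Status* s); /* hypothetical */
TF_Output BuildBody(TF_Graph* g, TF_Output in, TF_Status* s); /* hypothetical */

/* Wrap one loop variable `init` in a while loop; return its final value. */
TF_Output MakeLoop(TF_Graph* g, TF_Output init, TF_Status* status) {
  TF_WhileParams params = TF_NewWhile(g, &init, 1, status);
  TF_Output result = {NULL, -1};
  if (TF_GetCode(status) != TF_OK) return result;  /* params is not valid */
  params.name = "my_loop";
  /* Build the condition and body subgraphs from the provided inputs. */
  params.cond_output = BuildCond(params.cond_graph, params.cond_inputs[0], status);
  params.body_outputs[0] = BuildBody(params.body_graph, params.body_inputs[0], status);
  /* `params` is no longer valid once TF_FinishWhile returns. */
  TF_FinishWhile(&params, status, &result);
  return result;
}
```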
-TF_Session* TF_LoadSessionFromSavedModel( +TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( const TF_SessionOptions* session_options, const TF_Buffer* run_options, const char* export_dir, const char* const* tags, int tags_len, TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); -#endif // __ANDROID__ // Close a session. // // Contacts any other processes associated with the session, if applicable. // May not be called after TF_DeleteSession(). -extern void TF_CloseSession(TF_Session*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); // Destroy a session object. // @@ -906,7 +1052,7 @@ extern void TF_CloseSession(TF_Session*, TF_Status* status); // local resources associated with the session. The session may not be used // during or after this call (and the session drops its reference to the // corresponding graph). -extern void TF_DeleteSession(TF_Session*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); // Run the graph associated with the session starting with the supplied inputs // (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). @@ -932,58 +1078,61 @@ extern void TF_DeleteSession(TF_Session*, TF_Status* status); // to the caller, which must eventually call TF_DeleteTensor on them. // // On failure, output_values[] contains NULLs. -extern void TF_SessionRun(TF_Session* session, - // RunOptions - const TF_Buffer* run_options, - // Input tensors - const TF_Output* inputs, - TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, - int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // RunMetadata - TF_Buffer* run_metadata, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionRun( + TF_Session* session, + // RunOptions + const TF_Buffer* run_options, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // RunMetadata + TF_Buffer* run_metadata, + // Output status + TF_Status*); // Set up the graph with the intended feeds (inputs) and fetches (outputs) for a // sequence of partial run calls. // -// On success, returns a handle that is used for subsequent PRun calls. +// On success, returns a handle that is used for subsequent PRun calls. The +// handle should be deleted with TF_DeletePRunHandle when it is no longer +// needed. // // On failure, out_status contains a tensorflow::Status with an error // message. // NOTE: This is EXPERIMENTAL and subject to change. -extern void TF_SessionPRunSetup(TF_Session*, - // Input names - const TF_Output* inputs, int ninputs, - // Output names - const TF_Output* outputs, int noutputs, - // Target operations - const TF_Operation* const* target_opers, - int ntargets, - // Output handle - const char** handle, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionPRunSetup( + TF_Session*, + // Input names + const TF_Output* inputs, int ninputs, + // Output names + const TF_Output* outputs, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); // Continue to run the graph with additional feeds and fetches. 
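The TF_SessionRun() contract above, condensed into a single-feed, single-fetch sketch; the caller keeps ownership of `input_tensor` and must eventually delete the returned tensor:

```c
#include <stddef.h>
#include "tensorflow/c/c_api.h"

/* Run the graph once, feeding `input` and fetching `output`. */
TF_Tensor* RunOnce(TF_Session* session, TF_Output input,
                   TF_Tensor* input_tensor, TF_Output output,
                   TF_Status* status) {
  TF_Tensor* result = NULL;
  TF_SessionRun(session, /* run_options */ NULL,
                &input, &input_tensor, 1,  /* feeds */
                &output, &result, 1,       /* fetches */
                NULL, 0,                   /* no target operations */
                /* run_metadata */ NULL, status);
  return result;  /* NULL on failure; else caller calls TF_DeleteTensor */
}
```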
The // execution state is uniquely identified by the handle. // NOTE: This is EXPERIMENTAL and subject to change. -extern void TF_SessionPRun(TF_Session*, const char* handle, - // Input tensors - const TF_Output* inputs, - TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, - int noutputs, - // Target operations - const TF_Operation* const* target_opers, - int ntargets, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionPRun( + TF_Session*, const char* handle, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output status + TF_Status*); + +// Deletes a handle allocated by TF_SessionPRunSetup. +// Once called, no more calls to TF_SessionPRun should be made. +TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); // -------------------------------------------------------------------------- // The deprecated session API. Please switch to the above instead of @@ -992,39 +1141,96 @@ extern void TF_SessionPRun(TF_Session*, const char* handle, typedef struct TF_DeprecatedSession TF_DeprecatedSession; -extern TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions*, +TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( + const TF_SessionOptions*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, TF_Status* status); -extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, TF_Status* status); -extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -extern void TF_Reset(const TF_SessionOptions* opt, const char** containers, - int ncontainers, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status); // Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and // add the nodes in that GraphDef to the graph for the session. // // Prefer use of TF_Session and TF_GraphImportGraphDef over this. -extern void TF_ExtendGraph(TF_DeprecatedSession*, const void* proto, - size_t proto_len, TF_Status*); +TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, + const void* proto, size_t proto_len, + TF_Status*); // See TF_SessionRun() above. -extern void TF_Run(TF_DeprecatedSession*, const TF_Buffer* run_options, - const char** input_names, TF_Tensor** inputs, int ninputs, - const char** output_names, TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status*); +TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, + const TF_Buffer* run_options, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status*); // See TF_SessionPRunSetup() above. 
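The partial-run calls above compose as follows; a minimal sketch in the spirit of the SessionPRun test further down, with error checks elided:

```c
#include <stddef.h>
#include "tensorflow/c/c_api.h"

/* Set up a partial run over one feed and one fetch, run it, clean up. */
void PartialRunOnce(TF_Session* session, TF_Output feed, TF_Tensor* feed_val,
                    TF_Output fetch, TF_Status* status) {
  const char* handle = NULL;
  TF_SessionPRunSetup(session, &feed, 1, &fetch, 1, NULL, 0, &handle, status);

  TF_Tensor* out = NULL;
  TF_SessionPRun(session, handle, &feed, &feed_val, 1, &fetch, &out, 1,
                 NULL, 0, status);
  if (out != NULL) TF_DeleteTensor(out);

  TF_DeletePRunHandle(handle);  /* no TF_SessionPRun calls after this */
}
```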
-extern void TF_PRunSetup(TF_DeprecatedSession*, const char** input_names, - int ninputs, const char** output_names, int noutputs, - const char** target_oper_names, int ntargets, - const char** handle, TF_Status*); +TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, + const char** input_names, int ninputs, + const char** output_names, int noutputs, + const char** target_oper_names, + int ntargets, const char** handle, + TF_Status*); // See TF_SessionPRun above. -extern void TF_PRun(TF_DeprecatedSession*, const char* handle, - const char** input_names, TF_Tensor** inputs, int ninputs, - const char** output_names, TF_Tensor** outputs, - int noutputs, const char** target_oper_names, int ntargets, - TF_Status*); +TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Status*); + +typedef struct TF_DeviceList TF_DeviceList; + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, + TF_Status* status); + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( + TF_DeprecatedSession* session, TF_Status* status); + +// Deallocates the device list. +TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); + +// Counts the number of elements in the device list. +TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); + +// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) +// The return value will be a pointer to a null terminated string. The caller +// must not modify or delete the string. It will be deallocated upon a call to +// TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. +TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, + int index, TF_Status*); + +// Retrieves the type of the device at the given index. +// +// The caller must not modify or delete the string. It will be deallocated upon +// a call to TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. +TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, + int index, TF_Status*); + +// Retrieve the amount of memory associated with a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and -1 will be returned. +TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( + const TF_DeviceList* list, int index, TF_Status*); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels @@ -1043,19 +1249,19 @@ typedef struct TF_Library TF_Library; // The caller owns the library handle. // // On failure, place an error status in status and return NULL. 
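A sketch tying the TF_DeviceList accessors above together; the output format is illustrative, and each accessor reuses the same status object:

```c
#include <stdio.h>
#include "tensorflow/c/c_api.h"

/* Enumerate the devices visible to `session`. */
void PrintDevices(TF_Session* session, TF_Status* status) {
  TF_DeviceList* devices = TF_SessionListDevices(session, status);
  if (TF_GetCode(status) != TF_OK) return;
  int n = TF_DeviceListCount(devices);
  for (int i = 0; i < n; ++i) {
    printf("%s (%s): %lld bytes\n",
           TF_DeviceListName(devices, i, status),
           TF_DeviceListType(devices, i, status),
           (long long)TF_DeviceListMemoryBytes(devices, i, status));
  }
  TF_DeleteDeviceList(devices);  /* also invalidates the returned strings */
}
```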
-extern TF_Library* TF_LoadLibrary(const char* library_filename, -                                  TF_Status* status); +TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, +                                                 TF_Status* status); // Get the OpList of OpDefs defined in the library pointed to by lib_handle. // // Returns a TF_Buffer. The memory pointed to by the result is owned by // lib_handle. The data in the buffer will be the serialized OpList proto for // ops defined in the library. -extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); +TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); // Frees the memory associated with the library handle. // Does NOT unload the library. -extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); +TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // Get the OpList of all OpDefs defined in this address space. // Returns a TF_Buffer, ownership of which is transferred to the caller // (and can be freed using TF_DeleteBuffer). // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h new file mode 100644 index 00000000000..f17ac26ad96 --- /dev/null +++ b/tensorflow/c/c_api_internal.h @@ -0,0 +1,119 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TF_Status { + tensorflow::Status status; +}; + +struct TF_Tensor { + TF_DataType dtype; + tensorflow::TensorShape shape; + tensorflow::TensorBuffer* buffer; +}; + +struct TF_SessionOptions { + tensorflow::SessionOptions options; +}; + +struct TF_DeprecatedSession { + tensorflow::Session* session; +}; + +struct TF_Library { + void* lib_handle; + TF_Buffer op_list; +}; + +struct TF_Graph { + TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + tensorflow::mutex mu; + tensorflow::Graph graph GUARDED_BY(mu); + + // Runs shape inference.
+ tensorflow::ShapeRefiner refiner GUARDED_BY(mu); + + // Maps from name of an operation to the Node* in 'graph'. + std::unordered_map<tensorflow::string, tensorflow::Node*> name_map +     GUARDED_BY(mu); + + // TF_Graph may only / must be deleted when + // num_sessions == 0 && delete_requested == true + + // num_sessions incremented by TF_NewSession, and decremented by + // TF_DeleteSession. + int num_sessions GUARDED_BY(mu); + bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph + + // Used to link graphs contained in TF_WhileParams to the parent graph that + // will eventually contain the full while loop. + TF_Graph* parent; + TF_Output* parent_inputs; +}; + +struct TF_OperationDescription { + TF_OperationDescription(TF_Graph* g, const char* op_type, +                         const char* node_name) +     : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} + + tensorflow::NodeBuilder node_builder; + TF_Graph* graph; + std::vector<tensorflow::string> colocation_constraints; +}; + +struct TF_Operation { + tensorflow::Node node; +}; + +struct TF_Session { + TF_Session(tensorflow::Session* s, TF_Graph* g) +     : session(s), graph(g), last_num_graph_nodes(0) {} + tensorflow::Session* session; + TF_Graph* graph; + tensorflow::mutex mu; + int last_num_graph_nodes; +}; + +struct TF_ImportGraphDefOptions { + tensorflow::ImportGraphDefOptions opts; +}; + +struct TF_DeviceList { + std::vector<tensorflow::DeviceAttributes> response; +}; diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 5591409d99b..04540bd793d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/util/equal_graph_def.h" using tensorflow::int32; using tensorflow::string; @@ -105,6 +107,22 @@ TEST(CAPI, AllocateTensor) { TF_DeleteTensor(t); } +TEST(CAPI, MaybeMove) { + const int num_bytes = 6 * sizeof(float); + float* values = +     reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw( +         EIGEN_MAX_ALIGN_BYTES, num_bytes)); + int64_t dims[] = {2, 3}; + bool deallocator_called = false; + TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, +                             &Deallocator, &deallocator_called); + + TF_Tensor* o = TF_TensorMaybeMove(t); + ASSERT_TRUE(o == nullptr); // It is unsafe to move memory TF might not own. + TF_DeleteTensor(t); + EXPECT_TRUE(deallocator_called); +} + TEST(CAPI, LibraryLoadFunctions) { // Load the library. TF_Status* status = TF_NewStatus(); @@ -261,6 +279,19 @@ static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast<int32*>(data); } +// Create a tensor with values of type TF_INT8 provided by `values`.
+static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, + const char* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); + memcpy(TF_TensorData(t), values, sizeof(char) * num_values); + return t; +} + static TF_Tensor* Int32Tensor(int32 v) { const int num_bytes = sizeof(int32); int32* values = new int32[1]; @@ -269,29 +300,44 @@ static TF_Tensor* Int32Tensor(int32 v) { &Int32Deallocator, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed"); +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, + const char* name = "feed") { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", TF_INT32); return TF_FinishOperation(desc, s); } -TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) { - unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); - TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar"); - TF_SetAttrTensor(desc, "value", tensor.get(), s); +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name = "const") { + TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", t, s); if (TF_GetCode(s) != TF_OK) return nullptr; - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", TF_TensorType(t)); return TF_FinishOperation(desc, s); } +TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar") { + unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add"); + TF_Status* s, const char* name = "add") { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); return TF_FinishOperation(desc, s); } +TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, + const char* name = "add") { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output inputs[2] = {l, r}; + TF_AddInputList(desc, inputs, 2); + return TF_FinishOperation(desc, s); +} + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); TF_Output neg_input = {n, 0}; @@ -299,6 +345,14 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { return TF_FinishOperation(desc, s); } +TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, + TF_Status* s) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than"); + TF_AddInput(desc, l); + TF_AddInput(desc, r); + return TF_FinishOperation(desc, s); +} + bool IsPlaceholder(const NodeDef& node_def) { if (node_def.op() != "Placeholder" || node_def.name() != "feed") { return false; @@ -667,6 +721,28 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +/* +TODO(skyewm): this test currently DCHECKs, change to bad status + +TEST(CAPI, InputFromDifferentGraphError) { + TF_Status* s = TF_NewStatus(); + TF_Graph* g1 = TF_NewGraph(); + TF_Graph* g2 = TF_NewGraph(); + + TF_Operation* feed = Placeholder(g1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Attempt to 
create node in g2 with input from g1 + Neg(feed, g2, s); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); + EXPECT_STREQ("foo", TF_Message(s)); + + TF_DeleteGraph(g1); + TF_DeleteGraph(g2); + TF_DeleteStatus(s); +} +*/ + TEST(CAPI, ImportGraphDef) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -765,6 +841,33 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed2, control_inputs[1]); + // Export to a graph def so we can import a graph with control dependencies + TF_DeleteBuffer(graph_def); + graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import again, with remapped control dependency, into the same graph + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar4 = +     TF_GraphOperationByName(graph, "imported4/imported3/scalar"); + TF_Operation* feed4 = +     TF_GraphOperationByName(graph, "imported4/imported2/feed"); + + // Check that imported `imported3/scalar` has remapped control dep from + // original graph and imported control dep + num_control_inputs = TF_OperationGetControlInputs( +     scalar4, control_inputs, TF_OperationNumControlInputs(scalar4)); + ASSERT_EQ(2, num_control_inputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed4, control_inputs[1]); + TF_DeleteImportGraphDefOptions(opts); TF_DeleteBuffer(graph_def); @@ -784,7 +887,7 @@ class CSession { TF_DeleteSessionOptions(opts); } - CSession(TF_Session* session) { session_ = session; } + explicit CSession(TF_Session* session) : session_(session) {} ~CSession() { TF_Status* s = TF_NewStatus(); @@ -793,8 +896,7 @@ class CSession { TF_DeleteStatus(s); } - void SetInputs( -     std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) { + void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) { DeleteInputValues(); inputs_.clear(); for (const auto& p : inputs) { @@ -811,6 +913,11 @@ class CSession { } } + void SetOutputs(const std::vector<TF_Output>& outputs) { + ResetOutputValues(); + outputs_ = outputs; + } + void SetTargets(std::initializer_list<TF_Operation*> targets) { targets_.clear(); for (TF_Operation* t : targets) { @@ -937,6 +1044,103 @@ TEST(CAPI, Session) { TF_DeleteStatus(s); } +TEST(CAPI, SessionPRun) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Construct the graph: A + 2 + B + TF_Operation* a = Placeholder(graph, s, "A"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* b = Placeholder(graph, s, "B"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* two = ScalarConst(2, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* plus2 = Add(a, two, graph, s, "plus2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* plusB = Add(plus2, b, graph, s, "plusB"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Setup a session and a partial run handle. The partial run will allow + // computation of A + 2 + B in two phases (calls to TF_SessionPRun): + // 1. Feed A and get (A+2) + // 2.
Feed B and get (A+2)+B + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph, opts, s); + TF_DeleteSessionOptions(opts); + + TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}}; + TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}}; + + const char* handle = nullptr; + TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches, + TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Feed A and fetch A + 2. + TF_Output feeds1[] = {TF_Output{a, 0}}; + TF_Output fetches1[] = {TF_Output{plus2, 0}}; + TF_Tensor* feedValues1[] = {Int32Tensor(1)}; + TF_Tensor* fetchValues1[1]; + TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1, + 1, nullptr, 0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(3, *(static_cast(TF_TensorData(fetchValues1[0])))); + TF_DeleteTensor(feedValues1[0]); + TF_DeleteTensor(fetchValues1[0]); + + // Feed B and fetch (A + 2) + B. + TF_Output feeds2[] = {TF_Output{b, 0}}; + TF_Output fetches2[] = {TF_Output{plusB, 0}}; + TF_Tensor* feedValues2[] = {Int32Tensor(4)}; + TF_Tensor* fetchValues2[1]; + TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2, + 1, nullptr, 0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(7, *(static_cast(TF_TensorData(fetchValues2[0])))); + TF_DeleteTensor(feedValues2[0]); + TF_DeleteTensor(fetchValues2[0]); + + // Clean up. + TF_DeletePRunHandle(handle); + TF_DeleteSession(sess, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, ShapeInferenceError) { + // TF_FinishOperation should fail if the shape of the added operation cannot + // be inferred. + TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create this failure by trying to add two nodes with incompatible shapes + // (A tensor with shape [2] and a tensor with shape [3] cannot be added). + const char data[] = {1, 2, 3}; + const int64_t vec2_dims[] = {2}; + unique_tensor_ptr vec2_tensor( + Int8Tensor(vec2_dims, TF_ARRAYSIZE(vec2_dims), data), TF_DeleteTensor); + TF_Operation* vec2 = Const(vec2_tensor.get(), graph, status, "vec2"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const int64_t vec3_dims[] = {3}; + unique_tensor_ptr vec3_tensor( + Int8Tensor(vec3_dims, TF_ARRAYSIZE(vec3_dims), data), TF_DeleteTensor); + TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Operation* add = Add(vec2, vec3, graph, status); + ASSERT_NE(TF_OK, TF_GetCode(status)); + ASSERT_TRUE(add == nullptr); + + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + TEST(CAPI, ColocateWith) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1068,16 +1272,582 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } -// Create a tensor with values of type TF_INT8 provided by `values`. 
-TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { - int64_t num_values = 1; - for (int i = 0; i < num_dims; ++i) { -   num_values *= dims[i]; +class CApiWhileLoopTest : public ::testing::Test { + protected: + CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} + + ~CApiWhileLoopTest() override { + TF_DeleteGraph(graph_); + TF_DeleteStatus(s_); } - TF_Tensor* t = -     TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); - memcpy(TF_TensorData(t), values, sizeof(char) * num_values); - return t; + + void Init(int ninputs) { + DCHECK(inputs_.empty()); + DCHECK_GT(ninputs, 0); + + for (int i = 0; i < ninputs; ++i) { +   TF_Operation* placeholder = Placeholder( +       graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str()); +   DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +   inputs_.push_back({placeholder, 0}); + } + + original_graph_description_ = GraphDebugString(); + + params_.reset(new TF_WhileParams( +     TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_EQ(original_graph_description_, GraphDebugString()) +     << "TF_NewWhile() altered graph"; + + params_->name = "test_loop"; + + // Initialize outputs_ so we can easily detect errors/bugs + outputs_.resize(ninputs, {nullptr, -1}); + } + + void ExpectOK() { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectError(TF_Code expected_code, const string& expected_msg) { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(expected_code, TF_GetCode(s_)); + EXPECT_EQ(expected_msg, TF_Message(s_)); + // TODO(skyewm): this assert is currently broken. Fix or remove guarantee. + // ASSERT_EQ(original_graph_description_, GraphDebugString()) << + //     "TF_FinishWhile() altered graph on error"; + } + + void Run(std::initializer_list<int> input_values) { + DCHECK_EQ(inputs_.size(), input_values.size()); + std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size()); + int i = 0; + for (int v : input_values) { +   inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; +   ++i; + } + csession_.reset(new CSession(graph_, s_)); + csession_->SetInputs(inputs); + csession_->SetOutputs(outputs_); + csession_->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectOutputValue(int idx, int expected_value) { + TF_Tensor* out = csession_->output_tensor(idx); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* data = static_cast<int32*>(TF_TensorData(out)); + EXPECT_EQ(expected_value, *data); + } + + // Create a valid conditional graph. Useful for testing unrelated errors.
+ void CreateCondGraph() { + TF_Operation* one = ScalarConst(1, params_->cond_graph, s_); + TF_Operation* less_than = +     LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + } + + string GraphDebugString() const { + TF_Buffer* buf = TF_NewBuffer(); + TF_GraphToGraphDef(graph_, buf, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + GraphDef def; + bool success = def.ParseFromArray(buf->data, buf->length); + DCHECK(success); + TF_DeleteBuffer(buf); + return def.DebugString(); + } + + TF_Status* s_; + TF_Graph* graph_; + std::vector<TF_Output> inputs_;   // The inputs to the while loop + std::vector<TF_Output> outputs_;  // The final outputs of the while loop + std::unique_ptr<TF_WhileParams> params_; + std::unique_ptr<CSession> csession_; + + private: + // Used to verify that errors don't change graph_ + string original_graph_description_; +}; + +TEST_F(CApiWhileLoopTest, BasicLoop) { + Init(2); + + // Validate TF_WhileParams returned by TF_NewWhile() + EXPECT_TRUE(params_->body_graph != nullptr); + EXPECT_TRUE(params_->cond_graph != nullptr); + + EXPECT_EQ(params_->ninputs, 2); + + ASSERT_TRUE(params_->cond_inputs != nullptr); + ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_inputs != nullptr); + EXPECT_TRUE(params_->body_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->body_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_outputs != nullptr); + + // Create loop: while (input1 < input2) input1 += input2 + 1 + TF_Operation* less_than = +     LessThan(params_->cond_inputs[0], params_->cond_inputs[1], +              params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1], +                          params_->body_graph, s_, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* one = ScalarConst(1, params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {add2, 0}; + params_->body_outputs[1] = params_->body_inputs[1]; + + // Finalize while loop + ExpectOK(); + + // Validate while loop outputs returned by TF_FinishWhile() + EXPECT_TRUE(outputs_[0].oper != nullptr); + EXPECT_GE(outputs_[0].index, 0); + EXPECT_TRUE(outputs_[1].oper != nullptr); + EXPECT_GE(outputs_[1].index, 0); + + // Run the graph + Run({-9, 2}); + ExpectOutputValue(0, 3); + ExpectOutputValue(1, 2); +} + +TEST_F(CApiWhileLoopTest, NestedLoop) { + Init(2); + // Create nested loop: + //   while (input1 < 6) { + //     inner_input1 = input1 + //     while (inner_input1 < 3) { + //       input2 += 1 + //       inner_input1 += 2 + //     } + //     input1 += input2 + //   } + // + // Expected execution with initial values input1 = input2 = 0: + // + // outer inner               inner_ + // step# step# input1 input2 input1 + // ------------------------------------ + //   0     0     0      0      0 + //   0     1     0      1      2 + //   0     2     0      2      4 + //   0     -     2      2      - + //   1     0     2      2      2 + //   1     1     2      3      4 + //   1     -     5      3      - + //   2     0     5      3      5 + //   2     -     8      3      - + + // Create outer cond graph + TF_Operation* six = ScalarConst(6, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* less_than = +     LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +
params_->cond_output = {less_than, 0}; + + // Create outer body graph + // Init inner graph + TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]}; + TF_WhileParams inner_params = + TF_NewWhile(params_->body_graph, inner_inputs, 2, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.name = "inner_loop"; + + // Create inner cond graph + TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* inner_less_than = LessThan( + inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.cond_output = {inner_less_than, 0}; + + // Create inner body graph + TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input2_add = + Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[1] = {input2_add, 0}; + + TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two, + inner_params.body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[0] = {inner_input1_add, 0}; + + // Finalize inner graph + TF_Output inner_outputs[2] = {{nullptr, -1}}; + TF_FinishWhile(&inner_params, s_, inner_outputs); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input1_add = + Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {input1_add, 0}; + + params_->body_outputs[1] = inner_outputs[1]; + + // Finalize outer graph + ExpectOK(); + + // Check for a few expected nodes + const char* node_name = "test_loop/cond/scalar"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/add"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/body/one"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/cond/less_than"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + + // Run the graph + Run({0, 0}); + ExpectOutputValue(0, 8); + ExpectOutputValue(1, 3); +} + +TEST_F(CApiWhileLoopTest, BadCondOutput) { + Init(1); + params_->body_outputs[0] = params_->body_inputs[0]; + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `cond_output` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, BadBodyOutput) { + Init(1); + CreateCondGraph(); + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `body_outputs[0]` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, NullName) { + Init(1); + CreateCondGraph(); + params_->body_outputs[0] = params_->body_inputs[0]; + params_->name = nullptr; + ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null"); +} + +TEST_F(CApiWhileLoopTest, WrongGraph) { + Init(1); + CreateCondGraph(); + // Set body output to output from outer graph + params_->body_outputs[0] = inputs_[0]; + // TODO(skyewm): improve error message + ExpectError(TF_INVALID_ARGUMENT, + "Requested return node 'p0' not found in graph def"); +} + +TEST_F(CApiWhileLoopTest, BadTypes) { + Init(1); + CreateCondGraph(); + // Op that has a float input + 
output + TF_OperationDescription* desc = TF_NewOperation( + params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op"); + TF_AddInput(desc, params_->body_inputs[0]); + TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + string msg(TF_Message(s_)); + EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while " + "building NodeDef 'float_op'"), + msg.npos); + TF_AbortWhile(params_.get()); +} + +REGISTER_OP("TestOpWithNoGradient") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Test op with no grad registered. + +x: input +y: output +)doc") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); + +class CApiGradientsTest : public ::testing::Test { + protected: + CApiGradientsTest() + : s_(TF_NewStatus()), + graph_(TF_NewGraph()), + expected_graph_(TF_NewGraph()) {} + + ~CApiGradientsTest() override { + TF_DeleteGraph(graph_); + TF_DeleteGraph(expected_graph_); + TF_DeleteStatus(s_); + } + + void TestGradientsSuccess(bool grad_inputs_provided) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + TF_Output expected_grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); + + AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); + + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Compare that the graphs match. + GraphDef expected_gdef; + GraphDef gdef; + EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef)); + EXPECT_TRUE(GetGraphDef(graph_, &gdef)); + TF_EXPECT_GRAPH_EQ(expected_gdef, gdef); + + // Compare that the output of the gradients of both graphs match. + RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs); + } + + void TestGradientsError(bool grad_inputs_provided) { + TF_Output inputs[1]; + TF_Output outputs[1]; + TF_Output grad_outputs[1]; + + BuildErrorGraph(inputs, outputs); + + AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + + string expected_msg = + "No gradient defined for op: TestOpWithNoGradient. Please see " + "https://www.tensorflow.org/code/" + "tensorflow/cc/gradients/README.md" + " for instructions on how to add C++ gradients."; + EXPECT_EQ(expected_msg, TF_Message(s_)); + } + + // Run the graph and ensure that the gradient values are as expected. 
+ void RunGraphsAndCompareOutputs(TF_Output* grad_outputs, + TF_Output* expected_grad_outputs) { + std::unique_ptr<CSession> csession(new CSession(graph_, s_)); + std::unique_ptr<CSession> expected_csession( + new CSession(expected_graph_, s_)); + + std::vector<TF_Output> grad_outputs_vec; + grad_outputs_vec.assign(grad_outputs, grad_outputs + 2); + csession->SetOutputs(grad_outputs_vec); + csession->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out0 = csession->output_tensor(0); + TF_Tensor* out1 = csession->output_tensor(1); + + std::vector<TF_Output> expected_grad_outputs_vec; + expected_grad_outputs_vec.assign(expected_grad_outputs, + expected_grad_outputs + 2); + expected_csession->SetOutputs(expected_grad_outputs_vec); + expected_csession->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* expected_out0 = expected_csession->output_tensor(0); + TF_Tensor* expected_out1 = expected_csession->output_tensor(1); + + CompareTensors(out0, expected_out0); + CompareTensors(out1, expected_out1); + } + + void CompareTensors(TF_Tensor* a, TF_Tensor* b) { + float* a_data = static_cast<float*>(TF_TensorData(a)); + float* b_data = static_cast<float*>(TF_TensorData(b)); + EXPECT_EQ(*a_data, *b_data); + } + + void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, + TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + if (grad_inputs_provided) { + TF_Output grad_inputs[1]; + const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; + TF_Operation* grad_inputs_op = + FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); + grad_inputs[0] = TF_Output{grad_inputs_op, 0}; + TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, + s_, grad_outputs); + } else { + TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, + grad_outputs); + } + } + + void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) { + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad"); + inputs[0] = TF_Output{const0, 0}; + outputs[0] = TF_Output{nograd, 0}; + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) { + // Construct the following graph: + // | + // z| + // | + // MatMul + // / \ + // ^ ^ + // | | + // x| y| + // | | + // | | + // Const_0 Const_1 + // + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + const float const1_val[] = {1.0, 0.0, 0.0, 1.0}; + TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1"); + TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul"); + inputs[0] = TF_Output{const0, 0}; + inputs[1] = TF_Output{const1, 0}; + outputs[0] = TF_Output{matmul, 0}; + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void BuildExpectedGraph(bool grad_inputs_provided, + TF_Output* expected_grad_outputs) { + // The expected graph looks like this if grad_inputs_provided. + // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
+ // ^ ^ + // dy| dx| // MatMul Gradient Graph + // | | + // MatMul_2 MatMul_1 + // ^ ^ ^ ^ + // | |----------| | + // | ^ | + // | dz| | + // | | | + // | Const_3 | + // | | + // | ^ | + // | z| | // MatMul Forward Graph + // | | | + // | MatMul | + // | / \ | + // | ^ ^ | + // | | | | + // |---x| y|----| + // | | + // | | + // Const_0 Const_1 + // + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + const float const1_val[] = {1.0, 0.0, 0.0, 1.0}; + TF_Operation* const0 = + FloatConst2x2(expected_graph_, s_, const0_val, "Const_0"); + TF_Operation* const1 = + FloatConst2x2(expected_graph_, s_, const1_val, "Const_1"); + TF_Operation* matmul = + MatMul(expected_graph_, s_, const0, const1, "MatMul"); + + TF_Operation* const3; + if (grad_inputs_provided) { + const float const3_val[] = {1.0, 1.0, 1.0, 1.0}; + const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs"); + } else { + const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike"); + } + + TF_Operation* matmul1 = + MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true); + TF_Operation* matmul2 = + MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false); + expected_grad_outputs[0] = {matmul1, 0}; + expected_grad_outputs[1] = {matmul2, 0}; + } + + TF_Tensor* FloatTensor2x2(const float* values) { + const int64_t dims[2] = {2, 2}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4); + memcpy(TF_TensorData(t), values, sizeof(float) * 4); + return t; + } + + TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, + const float* values, const char* name) { + unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor); + TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", tensor.get(), s); + if (TF_GetCode(s) != TF_OK) return nullptr; + TF_SetAttrType(desc, "dtype", TF_FLOAT); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, + TF_Operation* r, const char* name, + bool transpose_a = false, bool transpose_b = false) { + TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name); + if (transpose_a) { + TF_SetAttrBool(desc, "transpose_a", 1); + } + if (transpose_b) { + TF_SetAttrBool(desc, "transpose_b", 1); + } + TF_AddInput(desc, {l, 0}); + TF_AddInput(desc, {r, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in, + const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name); + TF_AddInput(desc, {in, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in, + const char* name) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "TestOpWithNoGradient", name); + TF_AddInput(desc, {in, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Status* s_; + TF_Graph* graph_; + TF_Graph* expected_graph_; +}; + +TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); } + +TEST_F(CApiGradientsTest, Gradients_NoGradInputs) { + TestGradientsSuccess(false); +} + +TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) { + TestGradientsError(true); +} + 
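For orientation, the raw call pattern these gradient tests exercise can be sketched outside the fixture as below. This is an illustrative fragment only, not part of the patch: `FloatScalar` and `AddSquareGradient` are hypothetical names, and error handling is trimmed to status checks.

```c++
#include "tensorflow/c/c_api.h"

// Hypothetical helper: builds a scalar float Const op in `g`.
TF_Operation* FloatScalar(TF_Graph* g, TF_Status* s, float v) {
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                                   sizeof(float));
  *static_cast<float*>(TF_TensorData(t)) = v;
  TF_OperationDescription* desc = TF_NewOperation(g, "Const", "x");
  TF_SetAttrTensor(desc, "value", t, s);
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_Operation* op = TF_FinishOperation(desc, s);
  TF_DeleteTensor(t);
  return op;
}

// Builds y = Square(x) and requests dy/dx, mirroring AddGradients above.
void AddSquareGradient(TF_Graph* g, TF_Status* s) {
  TF_Operation* x = FloatScalar(g, s, 3.0f);
  TF_OperationDescription* desc = TF_NewOperation(g, "Square", "y");
  TF_AddInput(desc, {x, 0});
  TF_Operation* y = TF_FinishOperation(desc, s);
  if (TF_GetCode(s) != TF_OK) return;

  TF_Output y_out = {y, 0};
  TF_Output x_out = {x, 0};
  TF_Output dx;
  // A null seed-gradient array starts backprop from ones, which is what the
  // *_NoGradInputs tests rely on.
  TF_AddGradients(g, &y_out, 1, &x_out, 1, /*dx=*/nullptr, s, &dx);
}
```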
+TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { + TestGradientsError(false); +} void StringVectorToArrays(const std::vector<string>& v, @@ -1095,9 +1865,13 @@ void StringVectorToArrays(const std::vector<string>& v, // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other // will have list(type). -#define ATTR_TEST_REGISTER_OP(type) \ - REGISTER_OP("CApiAttributesTestOp" #type).Attr("v: " #type); \ - REGISTER_OP("CApiAttributesTestOpList" #type).Attr("v: list(" #type ")") +#define ATTR_TEST_REGISTER_OP(type) \ + REGISTER_OP("CApiAttributesTestOp" #type) \ + .Attr("v: " #type) \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ + REGISTER_OP("CApiAttributesTestOpList" #type) \ + .Attr("v: list(" #type ")") \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape) ATTR_TEST_REGISTER_OP(string); ATTR_TEST_REGISTER_OP(int); ATTR_TEST_REGISTER_OP(float); @@ -1504,8 +2278,8 @@ TEST_F(CApiAttributesTest, TensorList) { EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i; EXPECT_EQ(tensor_ndims[i], TF_NumDims(v)) << i; for (int j = 0; j < TF_NumDims(v); ++j) { - EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j)) << "Tensor #" << i - << ", dimension #" << j; + EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j)) + << "Tensor #" << i << ", dimension #" << j; } EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i; EXPECT_EQ(0, diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 17b3f93193d..e7b9bca5b50 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -58,6 +58,7 @@ CheckpointReader::CheckpointReader(const string& filename, CheckpointReader::~CheckpointReader() { delete var_to_shape_map_ptr_; delete reader_; + delete v2_reader_; } bool CheckpointReader::HasTensor(const string& name) const { diff --git a/tensorflow/c/exported_symbols.lds b/tensorflow/c/exported_symbols.lds new file mode 100644 index 00000000000..a14bdaa48be --- /dev/null +++ b/tensorflow/c/exported_symbols.lds @@ -0,0 +1 @@ +_TF_* diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh new file mode 100755 index 00000000000..02a6a58b615 --- /dev/null +++ b/tensorflow/c/generate-pc.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +TF_PREFIX='/usr/local' + +usage() { + echo "Usage: $0 OPTIONS" + echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-v, --version\tset TensorFlow version" + echo -e "-h, --help\tdisplay this message" +} + +[ $# == 0 ] && usage && exit 0 + +# read the options +ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +eval set -- "$ARGS" + +# extract options and their arguments into variables.
+while true ; do + case "$1" in + -h|--help) usage ; exit ;; + -p|--prefix) + case "$2" in + "") shift 2 ;; + *) TF_PREFIX=$2 ; shift 2 ;; + esac ;; + -v|--version) + case "$2" in + "") shift 2 ;; + *) TF_VERSION=$2 ; shift 2 ;; + esac ;; + --) shift ; break ;; + *) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;; + esac +done + +[ -z $TF_VERSION ] && echo "Specify a version using -v or --version" && exit 1 + +echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" + +cat << EOF > tensorflow.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib +includedir=\${prefix}/include + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow +Cflags: -I\${includedir} +EOF diff --git a/tensorflow/c/version_script.lds b/tensorflow/c/version_script.lds new file mode 100644 index 00000000000..455bd7362bb --- /dev/null +++ b/tensorflow/c/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in c_api.h. + global: + TF_*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 054278bbf77..f89cc6384b3 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -8,8 +8,6 @@ package( licenses(["notice"]) # Apache 2.0 -exports_files(["LICENSE"]) - load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -36,6 +34,7 @@ cc_library( tf_cc_test( name = "framework_gradients_test", + size = "small", srcs = ["framework/gradients_test.cc"], deps = [ ":cc_ops", @@ -44,8 +43,8 @@ tf_cc_test( ":gradients", ":testutil", "//tensorflow/core:all_kernels", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -59,7 +58,6 @@ cc_library( deps = [ ":cc_ops", ":client_session", - ":grad_op_registry", ":gradients", ":ops", ":scope", @@ -72,6 +70,7 @@ cc_library( tf_cc_test( name = "framework_gradient_checker_test", + size = "small", srcs = ["framework/gradient_checker_test.cc"], deps = [ ":cc_ops", @@ -80,8 +79,8 @@ tf_cc_test( ":gradient_checker", ":testutil", "//tensorflow/core:all_kernels", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -93,6 +92,7 @@ cc_library( deps = [ ":array_grad", ":math_grad", + ":nn_grad", ], ) @@ -124,7 +124,10 @@ cc_library_with_android_deps( cc_library_with_android_deps( name = "scope", - srcs = ["framework/scope.cc"], + srcs = [ + "framework/scope.cc", + "framework/scope_internal.h", + ], hdrs = ["framework/scope.h"], android_deps = ["//tensorflow/core:android_tensorflow_lib"], common_deps = [ @@ -138,8 +141,18 @@ cc_library_with_android_deps( ], ) +cc_library_with_android_deps( + name = "scope_internal", + hdrs = ["framework/scope_internal.h"], + common_deps = [ + ":scope", + ], + deps = [], +) + tf_cc_test( name = "framework_scope_test", + size = "small", srcs = ["framework/scope_test.cc"], deps = [ ":ops", @@ -169,6 +182,7 @@ cc_library_with_android_deps( tf_cc_test( name = "client_client_session_test", + size = "small", srcs = ["client/client_session_test.cc"], deps = [ ":cc_ops", @@ -203,6 +217,7 @@ cc_library_with_android_deps( tf_cc_test( name = "ops_const_op_test", + size = "small", srcs = ["ops/const_op_test.cc"], deps = [ ":const_op", @@ -231,11 
+246,13 @@ cc_library( ":cc_ops_internal", ":grad_op_registry", ":gradients", + "//tensorflow/core:lib_proto_parsing", ], ) tf_cc_test( name = "gradients_array_grad_test", + size = "small", srcs = ["gradients/array_grad_test.cc"], deps = [ ":array_grad", @@ -266,6 +283,7 @@ cc_library( tf_cc_test( name = "gradients_math_grad_test", + size = "small", srcs = ["gradients/math_grad_test.cc"], deps = [ ":cc_ops", @@ -296,6 +314,7 @@ cc_library( tf_cc_test( name = "gradients_nn_grad_test", + size = "small", srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", @@ -315,6 +334,7 @@ tf_gen_op_wrappers_cc( name = "cc_ops", op_lib_names = [ "array_ops", + "audio_ops", "candidate_sampling_ops", "control_flow_ops", "data_flow_ops", @@ -343,6 +363,7 @@ tf_gen_op_wrappers_cc( tf_cc_test( name = "framework_cc_ops_test", + size = "small", srcs = ["framework/cc_ops_test.cc"], deps = [ ":cc_ops", @@ -376,6 +397,34 @@ tf_gen_op_wrappers_cc( visibility = ["//tensorflow:internal"], ) +tf_gen_op_wrappers_cc( + name = "functional_ops", + include_internal_ops = 1, + op_lib_names = [ + "functional_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + +tf_gen_op_wrappers_cc( + name = "resource_variable_ops", + include_internal_ops = 1, + op_lib_names = [ + "resource_variable_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + +tf_gen_op_wrappers_cc( + name = "remote_fused_graph_ops", + op_lib_names = [ + "remote_fused_graph_ops", + ], + pkg = "//tensorflow/core", +) + cc_library_with_android_deps( name = "cc_op_gen_main", srcs = [ @@ -414,7 +463,6 @@ cc_library( ":client_session", ":ops", ":scope", - "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:lib_internal", "//tensorflow/core:tensorflow", @@ -433,13 +481,25 @@ cc_binary( name = "tutorials_example_trainer", srcs = ["tutorials/example_trainer.cc"], copts = tf_copts(), - linkopts = [ - "-lpthread", - "-lm", - ], + linkopts = select({ + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//tensorflow:darwin": [ + "-lm", + "-lpthread", + ], + "//tensorflow:ios": [ + "-lm", + "-lpthread", + ], + "//conditions:default": [ + "-lm", + "-lpthread", + "-lrt", + ], + }), deps = [ ":cc_ops", - "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -471,7 +531,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:ops_util", ], ) @@ -512,6 +571,7 @@ cc_library( tf_cc_test( name = "coordinator_test", + size = "small", srcs = ["training/coordinator_test.cc"], deps = [ ":cc_ops", diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 644409203c1..ba056a8f3a8 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -16,32 +16,55 @@ limitations under the License. 
#include "tensorflow/cc/client/client_session.h" #include +#include #include #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +class ClientSession::Impl { + private: + friend class ClientSession; + + Impl(Session* session, std::shared_ptr graph) + : session_(session), graph_(std::move(graph)) {} + + static SessionOptions MakeDefaultSessionOptions(const string& target); + Status MaybeExtendGraph() const; + + std::unique_ptr session_; + std::shared_ptr graph_; + + mutable mutex mu_; + mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; +}; + ClientSession::ClientSession(const Scope& scope, const string& target) - : ClientSession(scope, MakeDefaultSessionOptions(target)) {} + : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {} ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {} ClientSession::ClientSession(const Scope& scope, - const SessionOptions& session_options) - : graph_(scope.graph_as_shared_ptr()) { + const SessionOptions& session_options) { Session* new_session; Status status = NewSession(session_options, &new_session); TF_CHECK_OK(status) << status; - session_.reset(new_session); - CHECK_NOTNULL(session_.get()); + impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr())); + CHECK_NOTNULL(impl()->session_.get()); } -SessionOptions ClientSession::MakeDefaultSessionOptions( - const string& target) const { +// Define destructor here so we can forward declare `Impl` in client_session.h. +// If we define a dtor in the header file or use the default dtor, +// unique_ptr needs the complete type. +ClientSession::~ClientSession() {} + +SessionOptions ClientSession::Impl::MakeDefaultSessionOptions( + const string& target) { SessionOptions options; options.env = Env::Default(); options.target = target; @@ -67,7 +90,7 @@ Status ClientSession::Run(const FeedType& inputs, nullptr); } -Status ClientSession::MaybeExtendGraph() const { +Status ClientSession::Impl::MaybeExtendGraph() const { mutex_lock l(mu_); int num_nodes = graph_->num_node_ids(); if (num_nodes > last_num_graph_nodes_) { @@ -90,16 +113,18 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, feeds.emplace_back(feed.first.name(), feed.second.tensor); } std::vector output_tensor_names; + output_tensor_names.reserve(fetch_outputs.size()); for (auto const& output : fetch_outputs) { output_tensor_names.push_back(output.name()); } std::vector target_node_names; + target_node_names.reserve(run_outputs.size()); for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); } - TF_RETURN_IF_ERROR(MaybeExtendGraph()); - return session_->Run(run_options, feeds, output_tensor_names, - target_node_names, outputs, run_metadata); + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->Run(run_options, feeds, output_tensor_names, + target_node_names, outputs, run_metadata); } } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 28ff3ec9641..5fb4109f7d1 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -23,14 +23,13 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +/// @addtogroup core +/// @{ + /// A `ClientSession` object lets the caller drive the evaluation of the /// TensorFlow graph constructed with the C++ API. /// @@ -64,6 +63,8 @@ class ClientSession { /// Create a new session, configuring it with `session_options`. ClientSession(const Scope& scope, const SessionOptions& session_options); + ~ClientSession(); + /// Evaluate the tensors in `fetch_outputs`. The values are returned as /// `Tensor` objects in `outputs`. The number and order of `outputs` will /// match `fetch_outputs`. @@ -89,18 +90,14 @@ class ClientSession { // TODO(keveman): Add support for partial run. private: - SessionOptions MakeDefaultSessionOptions(const string& target) const; - Status MaybeExtendGraph() const; - - std::unique_ptr session_; - std::shared_ptr graph_; - - mutable mutex mu_; - mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; - - TF_DISALLOW_COPY_AND_ASSIGN(ClientSession); + class Impl; + std::unique_ptr impl_; + Impl* impl() { return impl_.get(); } + const Impl* impl() const { return impl_.get(); } }; +/// @} + } // end namespace tensorflow #endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index 9c0f00f2b12..dfbac9788e1 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -49,7 +49,7 @@ TEST(ClientSessionTest, Feed) { TEST(ClientSessionTest, Extend) { Scope root = Scope::NewRootScope(); - auto a = Placeholder(root, DT_INT32); + auto a = Placeholder(root, DT_INT32, Placeholder::Shape({2})); auto c = Add(root, a, {2, 2}); ClientSession session(root); std::vector outputs; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 5f85d8c5edf..71aa986f918 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -57,6 +57,16 @@ string GetPath(const string& dot_h_fname) { return result; } +// Converts: some/path/to/file.xx +// to: file +// (note that suffix is removed) +string GetFilename(const string& path) { + size_t slash_pos = path.rfind('/'); + if (slash_pos == path.npos) slash_pos = -1; + size_t dot_pos = path.rfind('.'); + return path.substr(slash_pos + 1, dot_pos - (slash_pos + 1)); +} + // Converts: // cc/ops/gen_foo_ops.h // to: @@ -77,6 +87,17 @@ string ToGuard(const string& path) { return guard; } +// Converts: some_name_xyz +// to: Some Name Xyz +string ToTitle(const string& name) { + string title = name; + for (int i = 0; i < title.size(); ++i) { + if (title[i] == '_') title[i] = ' '; + } + str_util::TitlecaseString(&title, " "); + return title; +} + // Change: Into: // ABC /// ABC // /// @@ -105,7 +126,11 @@ string PrintString(const string& str) { return strings::StrCat("\"", str_util::CEscape(str), "\""); } -string PrintTensorShape(const TensorShape& shape) { +string PrintTensorShape(const TensorShapeProto& shape_proto) { + PartialTensorShape shape(shape_proto); + if (shape.IsIdenticalTo(PartialTensorShape())) { + return "::tensorflow::PartialTensorShape() /* unknown */"; + } string ret = "{"; for (int d = 0; d < shape.dims(); ++d) { if (d > 0) strings::StrAppend(&ret, ", "); 
@@ -167,7 +192,13 @@ string PrintTensor(const TensorProto& tensor_proto) { } } -string PrintAttrValue(string op, const AttrValue& attr_value) { +string PrintTensorProto(const TensorProto& proto) { + return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ", + PrintTensorShape(proto.tensor_shape()), + ").AsTensorProto()"); +} + +string PrintAttrValue(const string& op, const AttrValue& attr_value) { switch (attr_value.value_case()) { case AttrValue::kS: return PrintString(attr_value.s()); @@ -182,12 +213,9 @@ string PrintAttrValue(string op, const AttrValue& attr_value) { case AttrValue::kType: return EnumName_DataType(attr_value.type()); case AttrValue::kShape: - return PrintTensorShape(TensorShape(attr_value.shape())); + return PrintTensorShape(attr_value.shape()); case AttrValue::kTensor: - return strings::StrCat( - "Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ", - PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())), - ").AsTensorProto()"); + return PrintTensorProto(attr_value.tensor()); case AttrValue::kList: { string ret = "{"; if (attr_value.list().s_size() > 0) { @@ -220,8 +248,14 @@ string PrintAttrValue(string op, const AttrValue& attr_value) { } else if (attr_value.list().shape_size() > 0) { for (int i = 0; i < attr_value.list().shape_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); - strings::StrAppend( - &ret, PrintTensorShape(TensorShape(attr_value.list().shape(i)))); + strings::StrAppend(&ret, + PrintTensorShape(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + PrintTensorProto(attr_value.list().tensor(i))); } } strings::StrAppend(&ret, "}"); @@ -271,8 +305,8 @@ std::pair<StringPiece, bool> AttrTypeName(StringPiece attr_type) { {"list(bool)", {"gtl::ArraySlice<bool>", true}}, {"type", {"DataType", false}}, {"list(type)", {"DataTypeSlice", true}}, - {"shape", {"TensorShape", false}}, - {"list(shape)", {"gtl::ArraySlice<TensorShape>", true}}, + {"shape", {"PartialTensorShape", false}}, + {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}}, {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}}, {"func", {"NameAttrList", true}}, @@ -416,6 +450,7 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, } strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); + // Process inputs for (int i = 0; i < op_def.input_arg_size(); ++i) { const auto& arg(op_def.input_arg(i)); arg_types.push_back(strings::StrCat( @@ -430,30 +465,45 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, arg.description(), "\n"); } } + + // Process attrs + string required_attrs_comment; + string optional_attrs_comment; for (int i = 0; i < op_def.attr_size(); ++i) { const auto& attr(op_def.attr(i)); - // If the attr is going to be inferred or is optional, don't add it as a - // required argument. - if ((inferred_input_attrs.find(attr.name()) != - inferred_input_attrs.end()) || - attr.has_default_value()) { - continue; - } + // Skip inferred arguments + if (inferred_input_attrs.count(attr.name()) > 0) continue; + const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; + string attr_name = AvoidCPPKeywords(attr.name()); - arg_types.push_back(strings::StrCat(use_const ? "const " : "", - attr_type_name, use_const ?
"&" : "")); - arg_names.push_back(AvoidCPPKeywords(attr.name())); + string attr_comment; if (!attr.description().empty()) { - strings::StrAppend(&comment, "* ", AvoidCPPKeywords(attr.name()), ":\n"); // TODO(keveman): Word wrap and indent this, to handle multi-line // descriptions. - strings::StrAppend(&comment, " ", attr.description(), "\n"); + strings::StrAppend(&attr_comment, "* ", attr_name, ": ", + attr.description(), "\n"); + } + if (attr.has_default_value()) { + strings::StrAppend(&optional_attrs_comment, attr_comment); + } else { + strings::StrAppend(&required_attrs_comment, attr_comment); + arg_types.push_back(strings::StrCat( + use_const ? "const " : "", attr_type_name, use_const ? "&" : "")); + arg_names.push_back(attr_name); } } + strings::StrAppend(&comment, required_attrs_comment); + + if (!optional_attrs_comment.empty()) { + strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n"); + strings::StrAppend(&comment, optional_attrs_comment); + } + + // Process outputs for (int i = 0; i < op_def.output_arg_size(); ++i) { const auto& arg = op_def.output_arg(i); bool is_list = ArgIsList(arg); @@ -509,8 +559,6 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; - string attrs_comment = - strings::StrCat("Optional attribute setters for ", op_name, " :\n\n"); for (int i = 0; i < op_def.attr_size(); ++i) { const auto& attr(op_def.attr(i)); @@ -531,13 +579,15 @@ string OpInfo::GetOpAttrStruct() const { strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "", attr_type_name, use_const ? "&" : ""); - strings::StrAppend(&attrs_comment, attr_func_def, "): Defaults to ", - SummarizeAttrValue(attr.default_value()), "\n"); + string attr_comment; if (!attr.description().empty()) { - // TODO(keveman): Word wrap and indent this to handle multi-line - // description. 
- strings::StrAppend(&attrs_comment, " ", attr.description(), "\n"); + strings::StrAppend(&attr_comment, attr.description(), "\n\n"); } + strings::StrAppend(&attr_comment, "Defaults to ", + SummarizeAttrValue(attr.default_value()), "\n"); + attr_comment = MakeComment(attr_comment, " "); + + strings::StrAppend(&setters, attr_comment); strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n"); @@ -552,6 +602,8 @@ string OpInfo::GetOpAttrStruct() const { return ""; } + string attrs_comment = + strings::StrCat("Optional attribute setters for ", op_name, "\n"); string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); @@ -678,7 +730,7 @@ void OpInfo::GetOutput(string* out) const { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, - " for (int64 i = 0; i < ret->num_outputs(); ++i)\n"); + " for (int32 i = 0; i < ret->num_outputs(); ++i)\n"); strings::StrAppend(out, " this->", output_names[0], ".push_back(Output(ret, i));\n"); } else { @@ -688,11 +740,10 @@ void OpInfo::GetOutput(string* out) const { return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); - strings::StrAppend( - out, - " ::tensorflow::Status _status_ = " - "::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), " - "nullptr, &_outputs_range);\n"); + strings::StrAppend(out, + " ::tensorflow::Status _status_ = " + "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " + "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); @@ -701,7 +752,7 @@ void OpInfo::GetOutput(string* out) const { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { - strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ", + strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ", arg_range, ".second; ++i)\n"); strings::StrAppend(out, " this->", output_names[i], ".push_back(Output(ret, i));\n"); @@ -841,6 +892,10 @@ namespace ops { )include", "#include \"", op_header, "\"\n", namespace_begin); + const string filename = GetFilename(dot_h_fname); + const string doxygen = strings::StrCat("/// @defgroup ", filename, " ", + ToTitle(filename), "\n", "/// @{\n\n"); + TF_CHECK_OK(h->Append( strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" "#ifndef ", @@ -850,6 +905,7 @@ namespace ops { *op_header_guard, "\n\n"))); TF_CHECK_OK(h->Append(header)); TF_CHECK_OK(h->Append(namespace_begin)); + TF_CHECK_OK(h->Append(doxygen)); TF_CHECK_OK(cc->Append(cc_header)); } @@ -860,7 +916,9 @@ void FinishFiles(bool internal, WritableFile* h, WritableFile* cc, } // namespace tensorflow )footer" : - R"footer(} // namespace ops + R"footer(/// @} + +} // namespace ops } // namespace tensorflow )footer"; @@ -892,7 +950,7 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname, // Load the override map. OpGenOverrideMap override_map; if (!overrides_fnames.empty()) { - override_map.LoadFileList(env, overrides_fnames); + TF_CHECK_OK(override_map.LoadFileList(env, overrides_fnames)); } // Write the initial boilerplate to the .h and .cc files. 
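The user-visible effect of these cc_op_gen changes shows up in how generated wrappers are called: required attributes stay positional, while defaulted ones travel through the documented `Attrs` setters. A minimal client-side sketch, assuming the generated headers and build wiring are in place (`MatMul::TransposeB` appears in the tests later in this diff):

```c++
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;       // NOLINT(build/namespaces)
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)

  Scope root = Scope::NewRootScope();
  auto x = Const(root, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto y = Const(root, {{1.0f, 0.0f}, {0.0f, 1.0f}});
  // transpose_b has a default value, so it is passed via the generated
  // Attrs struct rather than as a positional argument.
  auto z = MatMul(root, x, y, MatMul::TransposeB(true));

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({z}, &outputs));
  return 0;
}
```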
diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 6dc0d84c16d..5da23036eaa 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -32,10 +32,11 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) { return BiasAdd(cop_scopes.last, m, b); } -void GetColocationConstraints(Output tensor, std::vector<string>* constraints) { +void GetColocationConstraints(const Output& tensor, + std::vector<string>* constraints) { constraints->clear(); - TF_EXPECT_OK( - GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints)); + TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName, + constraints)); } } // namespace @@ -158,11 +159,11 @@ TEST(CCOpTest, KernelLabel) { Scope root = Scope::NewRootScope(); auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f); TF_EXPECT_OK(root.status()); - const auto& attrs = add.z.op().node()->def().attr(); - ASSERT_TRUE(attrs.find("_kernel") != attrs.end()); - auto kernel_attr = attrs.find("_kernel")->second; - TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string")); - EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel"); + AttrSlice attrs = add.z.op().node()->attrs(); + const auto* kernel_attr = attrs.Find("_kernel"); + ASSERT_TRUE(kernel_attr); + TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string")); + EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel"); } TEST(CCOpTest, ColocateWith) { @@ -189,8 +190,7 @@ TEST(CCOpTest, ColocateWith) { Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4); auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7); - const auto& attrs = c6.op().node()->def().attr(); - EXPECT_TRUE(attrs.find("_class") == attrs.end()); + EXPECT_FALSE(c6.op().node()->attrs().Find("_class")); } TEST(CCOpTest, TemplatedConst) { diff --git a/tensorflow/cc/framework/grad_op_registry.cc b/tensorflow/cc/framework/grad_op_registry.cc index 0d6a377b507..254705736e7 100644 --- a/tensorflow/cc/framework/grad_op_registry.cc +++ b/tensorflow/cc/framework/grad_op_registry.cc @@ -32,7 +32,13 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) { Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const { auto iter = registry_.find(op); if (iter == registry_.end()) { - return errors::NotFound("No gradient defined for op: ", op); + const string error_msg = + "No gradient defined for op: " + op + + ". Please see " + "https://www.tensorflow.org/code/" + "tensorflow/cc/gradients/README.md" + " for instructions on how to add C++ gradients."; + return errors::NotFound(error_msg); } *func = iter->second; return Status::OK(); diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 849a8eed6f2..f3a7c138c4e 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -22,8 +22,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { // TODO(andydavis) Support returning relative error (as opposed to max error) @@ -39,14 +37,16 @@ Status ComputeTheoreticalJacobianTranspose( const std::vector& x_shapes, const std::vector& x_datas, const OutputList& ys, const std::vector& y_shapes, - std::vector& jacobian_ts) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + std::vector* jacobian_ts) { + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys'). OutputList dys; + dys.reserve(y_shapes.size()); for (const auto& y_shape : y_shapes) { // TODO(suharshs): This currently assumes that all x's are the same type. - dys.push_back(Cast(scope, Const(scope, 1.0, y_shape), xs[0].type())); + dys.push_back( + ops::Cast(scope, ops::Const(scope, 1.0, y_shape), xs[0].type())); } OutputList dxs; TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, ys, xs, dys, &dxs)); @@ -84,7 +84,7 @@ Status ComputeTheoreticalJacobianTranspose( for (int x_idx = 0; x_idx < x_num; x_idx++) { const int64 x_size = x_shapes[x_idx].num_elements(); - auto jacobian = jacobian_ts[x_idx * y_num + y_idx].matrix(); + auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); auto dx_flat = dxout[x_idx].flat(); for (int r = 0; r < x_size; ++r) { jacobian(r, c) = dx_flat(r); @@ -97,20 +97,20 @@ Status ComputeTheoreticalJacobianTranspose( return Status::OK(); } -Status EvaluateGraph(ClientSession& session, const OutputList& xs, - const OutputList& ys, std::vector& x_datas, +Status EvaluateGraph(ClientSession* session, const OutputList& xs, + const OutputList& ys, std::vector* x_datas, std::vector* y_datas) { // Create the feed list. ClientSession::FeedType feed_list; - for (int i = 0; i < x_datas.size(); i++) { - feed_list.insert({xs[i], x_datas[i]}); + for (int i = 0; i < x_datas->size(); i++) { + feed_list.insert({xs[i], (*x_datas)[i]}); } - TF_RETURN_IF_ERROR(session.Run(feed_list, ys, y_datas)); + TF_RETURN_IF_ERROR(session->Run(feed_list, ys, y_datas)); for (int y_idx = 0; y_idx < y_datas->size(); y_idx++) { - for (int x_idx = 0; x_idx < x_datas.size(); x_idx++) { + for (int x_idx = 0; x_idx < x_datas->size(); x_idx++) { Tensor y_data = (*y_datas)[y_idx]; - if (y_data.SharesBufferWith(x_datas[x_idx])) { + if (y_data.SharesBufferWith((*x_datas)[x_idx])) { // Create copies of outputs that share a buffer with any inputs since // the underlying buffer of the input Tensors are not copied for some // operations (i.e. Identity), which can lead to incorrect results for @@ -128,14 +128,14 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, const OutputList& ys, const std::vector& y_shapes, const T delta, - std::vector& x_datas, - std::vector& jacobian_ts) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + std::vector* x_datas, + std::vector* jacobian_ts) { + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); ClientSession session(scope); for (int x_idx = 0; x_idx < x_num; x_idx++) { - auto x_data_flat = x_datas[x_idx].flat(); + auto x_data_flat = (*x_datas)[x_idx].flat(); const int64 x_size = x_shapes[x_idx].num_elements(); // Compute the numeric Jacobian one column at a time by perturbing each @@ -147,11 +147,11 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, // Evaluate at positive delta. 
x_data_flat(r) = v + delta; std::vector<Tensor> y_pos; - TF_RETURN_IF_ERROR(EvaluateGraph(session, xs, ys, x_datas, &y_pos)); + TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_pos)); // Evaluate at negative delta. x_data_flat(r) = v - delta; std::vector<Tensor> y_neg; - TF_RETURN_IF_ERROR(EvaluateGraph(session, xs, ys, x_datas, &y_neg)); + TF_RETURN_IF_ERROR(EvaluateGraph(&session, xs, ys, x_datas, &y_neg)); for (int y_idx = 0; y_idx < y_num; y_idx++) { // Compute element-wise centered difference and store in each Jacobian. @@ -159,7 +159,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, auto y_neg_flat = y_neg[y_idx].flat<T>(); const int64 y_size = y_shapes[y_idx].num_elements(); const T scale = 2 * delta; - auto jacobian = jacobian_ts[x_idx * y_num + y_idx].matrix<T>(); + auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<T>(); for (int c = 0; c < y_size; ++c) { jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale; } @@ -175,11 +175,11 @@ template <typename T> void InitJacobians(const OutputList& xs, const std::vector<TensorShape>& x_shapes, const std::vector<TensorShape>& y_shapes, - std::vector<Tensor>& jacobians) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + std::vector<Tensor>* jacobians) { + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); - jacobians.resize(y_num * x_num); + jacobians->resize(y_num * x_num); for (int x_idx = 0; x_idx < x_num; x_idx++) { const int64 x_size = x_shapes[x_idx].num_elements(); for (int y_idx = 0; y_idx < y_num; y_idx++) { @@ -187,7 +187,7 @@ void InitJacobians(const OutputList& xs, Tensor jacobian_t(xs[x_idx].type(), {x_size, y_size}); auto jacobian_t_flat = jacobian_t.flat<T>(); jacobian_t_flat.setZero(); - jacobians[x_idx * y_num + y_idx] = std::move(jacobian_t); + (*jacobians)[x_idx * y_num + y_idx] = std::move(jacobian_t); } } } @@ -197,23 +197,23 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, const std::vector<TensorShape>& x_shapes, const OutputList& ys, const std::vector<TensorShape>& y_shapes, - std::vector<Tensor>& x_datas, + std::vector<Tensor>* x_datas, T* max_error) { // Initialize theoretical Jacobians to zeros. std::vector<Tensor> jacobian_ts; - InitJacobians(xs, x_shapes, y_shapes, jacobian_ts); + InitJacobians(xs, x_shapes, y_shapes, &jacobian_ts); // Compute theoretical Jacobian. TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose( - scope, xs, x_shapes, x_datas, ys, y_shapes, jacobian_ts)); + scope, xs, x_shapes, *x_datas, ys, y_shapes, &jacobian_ts)); // Initialize numeric Jacobian to zeros. std::vector<Tensor> jacobian_ns; - InitJacobians(xs, x_shapes, y_shapes, jacobian_ns); + InitJacobians(xs, x_shapes, y_shapes, &jacobian_ns); // Compute numeric Jacobian. TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose( - scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, jacobian_ns)); + scope, xs, x_shapes, ys, y_shapes, 1e-3, x_datas, &jacobian_ns)); for (int i = 0; i < jacobian_ts.size(); i++) { // Compute the maximum error between theoretical and numeric Jacobians. @@ -256,7 +256,7 @@ Status ComputeGradientError(const Scope& scope, const OutputList& xs, } // Compute gradient error. return ComputeGradientErrorInternal(scope, xs, x_shapes, ys, y_shapes, - x_datas, max_error); + &x_datas, max_error); } template <typename T> @@ -267,7 +267,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x, std::vector<Tensor> x_datas(1, Tensor(x_init_value)); // Compute gradient error.
return ComputeGradientErrorInternal(scope, {x}, {x_datas[0].shape()}, {y}, - {y_shape}, x_datas, max_error); + {y_shape}, &x_datas, max_error); } #define INSTANTIATE_GRAD_ERR_TYPE(T) \ diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index 998e9fe07dc..c5bddc50fcc 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/equal_graph_def.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 2c60f947a55..8c00a6f7049 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -210,8 +210,8 @@ Status SymbolicGradientBuilder::Initialize() { { // Initialize backprop with `grad_inputs_`. - const int num_dy = grad_inputs_.size(); - for (int i = 0; i < num_dy; ++i) { + const size_t num_dy = grad_inputs_.size(); + for (size_t i = 0; i < num_dy; ++i) { TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i])); } } @@ -308,7 +308,7 @@ Status SymbolicGradientBuilder::AddGradients() { continue; } - const int num_no_grad = no_grad_dy_indices.size(); + const size_t num_no_grad = no_grad_dy_indices.size(); if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) { // No grad defined for this op, or all outputs returned 'NoGradient': // Backprop 'NoGradient' along the in edges. @@ -367,6 +367,19 @@ Status AddSymbolicGradients(const Scope& scope, return builder.AddGradients(); } +Status AddSymbolicGradients(const Scope& scope, + const std::vector<Output>& outputs, + const std::vector<Output>& inputs, + std::vector<Output>* grad_outputs) { + std::vector<Output> grad_inputs; + grad_inputs.reserve(outputs.size()); + for (const Output& output : outputs) { + grad_inputs.emplace_back(ops::OnesLike(scope, output)); + } + return AddSymbolicGradients(scope, outputs, inputs, grad_inputs, + grad_outputs); +} + Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); } } // end namespace tensorflow diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index d076bc43b4f..717f6f0636d 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -27,16 +27,19 @@ namespace tensorflow { /// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes /// to the graph associated with 'scope', which compute (and return in /// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'. /// - -// TODO(andydavis) Add overload of this function with no 'grad_inputs' arg. -// Implementation will fill in 'OnesLike' for all shapes in 'outputs'. Status AddSymbolicGradients(const Scope& scope, const std::vector<Output>& outputs, const std::vector<Output>& inputs, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs); +// Same as above, but uses 'OnesLike' for all shapes in +// 'outputs' as grad_inputs. +Status AddSymbolicGradients(const Scope& scope, + const std::vector<Output>& outputs, + const std::vector<Output>& inputs, + std::vector<Output>* grad_outputs); + /// Returns a sentinel Output that represents 'no gradient' (i.e.
no gradient /// flows along some graph edge during backpropagation). /// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients' diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 8c6e5de4259..6a249825812 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/graph/equal_graph_def.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) @@ -40,7 +40,7 @@ class GradientsTest : public ::testing::Test { TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test)); GraphDef gdef_exp; TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp)); - TF_EXPECT_GRAPH_EQ(gdef_test, gdef_exp); + TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test); } Scope scope_expected_; @@ -98,6 +98,32 @@ TEST_F(GradientsTest, OneMatMul) { CompareTestAndExpectedGraphs(); } +TEST_F(GradientsTest, OneMatMul_InferGradInputs) { + for (const bool expected : {false, true}) { + const Scope& scope = expected ? scope_expected_ : scope_test_; + // Construct forward graph. + auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}}); + auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}}); + auto z = MatMul(scope, x, y); + TF_ASSERT_OK(scope.status()); + CHECK_NOTNULL(z.node()); + + if (expected) { + // Construct backward graph. + // The gradients function adds a OnesLike to create a dz of ones with the + // shape of z. + auto dz = OnesLike(scope, z); + auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true)); + auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true)); + } else { + // Call AddSymbolicGradients. + std::vector<Output> grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs)); + } + } + CompareTestAndExpectedGraphs(); +} + TEST_F(GradientsTest, TwoMatMuls_Chained) { for (const bool expected : {false, true}) { const Scope& scope = expected ? scope_expected_ : scope_test_; @@ -234,7 +260,7 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) { } TEST_F(GradientsTest, DependentGradOutputs) { - // Tests that dependant gradients (in this case the gradients w.r.t to the + // Tests that dependent gradients (in this case the gradients w.r.t to the // output and one input of MatMul) are computed properly. // Create two chained MatMul ops.
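The OneMatMul_InferGradInputs test above drives the new AddSymbolicGradients overload end to end; outside the fixture the call reduces to this minimal sketch (assuming the usual `tensorflow` and `tensorflow::ops` namespaces are available):

```c++
Scope scope = Scope::NewRootScope();
auto x = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
auto y = ops::Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
auto z = ops::MatMul(scope, x, y);

std::vector<Output> grad_outputs;
// No grad_inputs argument: backprop is seeded with OnesLike(z) internally.
TF_CHECK_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
// grad_outputs[0] holds dz/dx and grad_outputs[1] holds dz/dy.
```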
diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc index 50df891a4c4..920a8e79556 100644 --- a/tensorflow/cc/framework/ops.cc +++ b/tensorflow/cc/framework/ops.cc @@ -20,7 +20,7 @@ namespace tensorflow { Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {} -Output Operation::input(int i) const { +Output Operation::input(int32 i) const { CHECK_NOTNULL(node_); CHECK_GE(i, 0); CHECK_LT(i, node_->num_inputs()); @@ -37,14 +37,14 @@ Output Operation::input(int i) const { return Output(inputs_[i].first, inputs_[i].second); } -Output Operation::output(int i) const { +Output Operation::output(int32 i) const { CHECK_NOTNULL(node_); CHECK_GE(i, 0); CHECK_LT(i, node_->num_outputs()); return Output(node_, i); } -uint64 Operation::hash(int64 index) const { +uint64 Operation::hash(int32 index) const { return ::tensorflow::Hash64(reinterpret_cast<const char*>(&node_), sizeof(Node*), index); } diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index d4f1079c3b2..8d4154220c4 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -26,30 +26,35 @@ limitations under the License. namespace tensorflow { +/// @defgroup core Core Tensorflow API + class Output; +/// @addtogroup core +/// @{ + /// Represents a node in the computation graph. class Operation { public: Operation() : node_(nullptr) {} explicit Operation(Node* n); - int num_inputs() const { return node_->num_inputs(); } - DataType input_type(int o) const { return node_->input_type(o); } - Output input(int i) const; + int32 num_inputs() const { return node_->num_inputs(); } + DataType input_type(int32 o) const { return node_->input_type(o); } + Output input(int32 i) const; - int num_outputs() const { return node_->num_outputs(); } - DataType output_type(int o) const { return node_->output_type(o); } - Output output(int i) const; + int32 num_outputs() const { return node_->num_outputs(); } + DataType output_type(int32 o) const { return node_->output_type(o); } + Output output(int32 i) const; Node* node() const { return node_; } - uint64 hash(int64 index) const; + uint64 hash(int32 index) const; bool operator==(const Operation& other) const { return node_ == other.node_; } private: - typedef std::vector<std::pair<Node*, int64>> Inputs; + typedef std::vector<std::pair<Node*, int32>> Inputs; static Inputs GetInputs(Node* node); Inputs inputs_; @@ -61,12 +66,12 @@ class Output { public: Output() = default; explicit Output(Node* n) : op_(n) {} - Output(Node* n, int64 index) : op_(n), index_(index) {} - Output(const Operation& op, int64 index) : op_(op), index_(index) {} + Output(Node* n, int32 index) : op_(n), index_(index) {} + Output(const Operation& op, int32 index) : op_(op), index_(index) {} Operation op() const { return op_; } Node* node() const { return op().node(); } - int64 index() const { return index_; } + int32 index() const { return index_; } DataType type() const { return op_.output_type(index_); } string name() const { return strings::StrCat(node()->name(), ":", index()); } bool operator==(const Output& other) const { @@ -77,13 +82,14 @@ class Output { private: Operation op_ = Operation(nullptr); - int64 index_ = 0; + int32 index_ = 0; }; +/// Hash class that can be used for e.g. storing Outputs in an unordered_map struct OutputHash { std::size_t operator()(const Output& output) const { return Hash64Combine(std::hash<Node*>()(output.node()), - std::hash<int64>()(output.index())); + std::hash<int32>()(output.index())); } }; @@ -161,6 +167,7 @@ class Input { /// initializer list is indeed a valid multi-dimensional tensor.
Initializer(const std::initializer_list<Initializer>& v); + // START_SKIP_DOXYGEN template <typename T, bool = std::is_convertible<T, string>::value> struct RealType { typedef string type; }; @@ -170,6 +177,7 @@ class Input { struct RealType<T, false> { typedef T type; }; + // END_SKIP_DOXYGEN TensorProto AsTensorProto() { TensorProto tensor_proto; @@ -222,12 +230,12 @@ class Input { /// Constructor specifying a node name, index and datatype. This should only /// be used for specifying a backward edge, needed by control flow. - Input(const string& name, int i, DataType dt) + Input(const string& name, int32 i, DataType dt) : node_name_(name), index_(i), data_type_(dt) {} Node* node() const { return output_.node(); } string node_name() const { return node_name_; } - int index() const { return node_name_.empty() ? output_.index() : index_; } + int32 index() const { return node_name_.empty() ? output_.index() : index_; } DataType data_type() const { return data_type_; } Status status() const { return status_; } const Tensor& tensor() const { return tensor_; } @@ -237,7 +245,7 @@ class Input { Output output_ = Output(Operation(nullptr), 0); Tensor tensor_; const string node_name_ = ""; - int index_ = 0; + int32 index_ = 0; DataType data_type_ = DT_INVALID; }; @@ -284,6 +292,8 @@ class InputList { std::vector<Input> inputs_; }; +/// @} + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index e1af5b36e8c..32c0822de69 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -16,15 +16,116 @@ limitations under the License. #include <algorithm> #include <vector> -#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" namespace tensorflow { -Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map, - ShapeRefiner* refiner) +class Scope::Impl { + public: + // A NameMap is used to keep track of suffixes for names used in a scope. A + // name that has not been used so far in a scope will get no suffix. Later + // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes + // can share the same NameMap. For instance, a new scope created using + // WithControlDependencies() would share the same NameMap with the + // parent. + typedef std::unordered_map<string, int> NameMap; + + Impl(const std::shared_ptr<Graph>& graph, + const std::shared_ptr<Status>& status, + const std::shared_ptr<NameMap>& name_map, + const std::shared_ptr<ShapeRefiner>& refiner); + + private: + friend class Scope; + + // Tag types to choose the constructor to dispatch.
+ struct Tags { + enum class ScopeName; + enum class OpName; + enum class ControlDeps; + enum class Device; + enum class SingleUseScope; + enum class ExitOnError; + enum class KernelLabel; + enum class Colocate; + }; + + Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); + Impl(const Scope& other, Tags::ScopeName, const string& name, + bool copy_names); + Impl(const Scope& other, Tags::OpName, const string& name, + const string& op_name); + Impl(const Scope& other, Tags::ControlDeps, + std::vector<Operation> control_deps, bool clear_control_deps); + Impl(const Scope& other, Tags::Device, const string& device); + Impl(const Scope& other, Tags::SingleUseScope, const string& op_name); + Impl(const Scope& other, Tags::ExitOnError); + Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); + Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, + bool clear_colocations); + + std::unordered_set<string> GetColocationConstraints( + const Operation& colocate_with_op) const; + + // Helper functions to get unique names. + string GetUniqueName(const string& prefix, bool check_single_use) const; + string GetNameForOp(const string& default_name) const; + + bool single_use_scope() const { return scope_used_ != nullptr; } + + // The graph, status, and name maps are shared by all child scopes + // created from a single 'root' scope. A root scope is created by calling the + // Scope::NewRootScope function, which creates a new graph, a new status and + // the name maps. + std::shared_ptr<Graph> graph_ = nullptr; + std::shared_ptr<Status> status_ = nullptr; + std::shared_ptr<NameMap> name_map_ = nullptr; + std::shared_ptr<ShapeRefiner> refiner_ = nullptr; + + // If scope_used_ is not nullptr, op_name_ should be empty and + // GetUniqueNameForOp can only be called once on this scope. More calls to + // GetUniqueNameForOp will cause an error status to be set on this scope.
+ std::shared_ptr<bool> scope_used_ = nullptr; + + const std::vector<Operation> control_deps_; + + const string name_ = ""; + const string op_name_ = ""; + const bool exit_on_error_ = false; + const string kernel_label_ = ""; + const string device_ = ""; + const std::unordered_set<string> colocation_constraints_; +}; + +Scope::Scope(Impl* impl) : impl_(impl) {} + +Scope::Scope(const Scope& other) : impl_(new Impl(*other.impl())) {} + +Scope::~Scope() {} + +Scope& Scope::operator=(const Scope& other) { + // We can't copy Impls because of the const members, use copy ctor instead + impl_.reset(new Impl(*other.impl_)); + return *this; +} + +Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, + ShapeRefiner* refiner) + : graph_(graph), + status_(status), + name_map_(name_map), + refiner_(refiner), + scope_used_(nullptr), + colocation_constraints_() {} + +Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, + const std::shared_ptr<Status>& status, + const std::shared_ptr<NameMap>& name_map, + const std::shared_ptr<ShapeRefiner>& refiner) : graph_(graph), status_(status), name_map_(name_map), @@ -34,143 +135,145 @@ Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map, Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); - ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry()); - return Scope(graph, new Status, new Scope::NameMap, refiner); + ShapeRefiner* refiner = + new ShapeRefiner(graph->versions().producer(), graph->op_registry()); + return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); } -Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name, - bool copy_names) - : graph_(other.graph_), - status_(other.status_), - name_map_(copy_names ? other.name_map_ +Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, + bool copy_names) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(copy_names ?
                           : std::shared_ptr<NameMap>(new NameMap)),
-      refiner_(other.refiner_),
+      refiner_(other.impl()->refiner_),
       scope_used_(nullptr),
-      control_deps_(other.control_deps_),
+      control_deps_(other.impl()->control_deps_),
       name_(name),
       op_name_(""),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::OpName, const string& name,
-             const string& op_name)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(other.control_deps_),
+Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
+                  const string& op_name)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
       name_(name),
       op_name_(op_name),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::ControlDeps,
-             std::vector<Operation> control_deps, bool clear_control_deps)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(clear_control_deps
-                        ? std::vector<Operation>()
-                        : (control_deps.insert(control_deps.begin(),
-                                               other.control_deps_.begin(),
-                                               other.control_deps_.end()),
-                           control_deps)),
-      name_(other.name_),
-      op_name_(other.op_name_),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
+                  std::vector<Operation> control_deps, bool clear_control_deps)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(
+          clear_control_deps
+              ? std::vector<Operation>()
+              : (control_deps.insert(control_deps.begin(),
+                                     other.impl()->control_deps_.begin(),
+                                     other.impl()->control_deps_.end()),
+                 control_deps)),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::Device, const string& device)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(other.control_deps_),
-      name_(other.name_),
-      op_name_(other.op_name_),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
+Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
       device_(device),
-      colocation_constraints_(other.colocation_constraints_) {}
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::SingleUseScope,
-             const string& op_name)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
+Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
+                  const string& op_name)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
       scope_used_(new bool(false)),
-      control_deps_(other.control_deps_),
-      name_(other.name_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
       op_name_(op_name),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::ExitOnError)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(other.control_deps_),
-      name_(other.name_),
-      op_name_(other.op_name_),
+Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
       exit_on_error_(true),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::KernelLabel,
-             const string& kernel_label)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(other.control_deps_),
-      name_(other.name_),
-      op_name_(other.op_name_),
-      exit_on_error_(other.exit_on_error_),
+Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
+                  const string& kernel_label)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(kernel_label),
-      device_(other.device_),
-      colocation_constraints_(other.colocation_constraints_) {}
+      device_(other.impl()->device_),
+      colocation_constraints_(other.impl()->colocation_constraints_) {}
 
-Scope::Scope(const Scope& other, Scope::Tags::Colocate,
-             const Operation& colocate_with_op, bool clear_colocations)
-    : graph_(other.graph_),
-      status_(other.status_),
-      name_map_(other.name_map_),
-      refiner_(other.refiner_),
-      scope_used_(other.scope_used_),
-      control_deps_(other.control_deps_),
-      name_(other.name_),
-      op_name_(other.op_name_),
-      exit_on_error_(other.exit_on_error_),
-      kernel_label_(other.kernel_label_),
-      device_(other.device_),
+Scope::Impl::Impl(const Scope& other, Tags::Colocate,
+                  const Operation& colocate_with_op, bool clear_colocations)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
       colocation_constraints_(
           clear_colocations
              ? std::unordered_set<string>()
-              : other.GetColocationConstraints(colocate_with_op)) {}
+              : other.impl()->GetColocationConstraints(colocate_with_op)) {}
 
-std::unordered_set<string> Scope::GetColocationConstraints(
+std::unordered_set<string> Scope::Impl::GetColocationConstraints(
     const Operation& colocate_with_op) const {
   std::unordered_set<string> current_constraints(colocation_constraints_);
-  const NodeDef& node_def = colocate_with_op.node()->def();
+  const AttrSlice attrs = colocate_with_op.node()->attrs();
   std::vector<string> node_constraints;
-  if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) {
+  if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
     for (const string& entry : node_constraints) {
       StringPiece s(entry);
       if (s.Consume(kColocationGroupPrefix)) {
@@ -183,45 +286,59 @@ std::unordered_set<string> Scope::GetColocationConstraints(
   return current_constraints;
 }
 
+bool Scope::ok() const { return impl()->status_->ok(); }
+
+Graph* Scope::graph() const { return impl()->graph_.get(); }
+
+std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
+  return impl()->graph_;
+}
+
+Status Scope::status() const { return *impl()->status_; }
+
+const std::vector<Operation>& Scope::control_deps() const {
+  return impl()->control_deps_;
+}
+
 void Scope::UpdateStatus(const Status s) const {
-  status_->Update(s);
-  if (exit_on_error_ && !status_->ok()) {
-    LOG(FATAL) << *status_;
+  impl()->status_->Update(s);
+  if (impl()->exit_on_error_ && !ok()) {
+    LOG(FATAL) << *impl()->status_;
   }
 }
 
 Status Scope::ToGraphDef(GraphDef* gdef) const {
-  if (!status_->ok()) {
-    return *status_;
+  if (!ok()) {
+    return *impl()->status_;
   }
   graph()->ToGraphDef(gdef);
   return Status::OK();
 }
 
 Status Scope::ToGraph(Graph* g) const {
-  if (status_->ok()) {
+  if (ok()) {
     GraphDef graph_def;
     graph()->ToGraphDef(&graph_def);
     GraphConstructorOptions opts;
     UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
   }
-  return *status_;
+  return *impl()->status_;
 }
 
 void Scope::UpdateBuilder(NodeBuilder* builder) const {
   std::vector<Node*> control_inputs;
-  for (const auto& op : control_deps_) {
+  for (const auto& op : impl()->control_deps_) {
     control_inputs.push_back(op.node());
   }
   builder->ControlInputs(control_inputs);
 
-  if (!kernel_label_.empty()) {
-    builder->Attr("_kernel", kernel_label_);
+  if (!impl()->kernel_label_.empty()) {
+    builder->Attr("_kernel", impl()->kernel_label_);
   }
 
-  if (!colocation_constraints_.empty()) {
-    std::vector<string> constraints(colocation_constraints_.begin(),
-                                    colocation_constraints_.end());
+  if (!impl()->colocation_constraints_.empty()) {
+    std::vector<string> constraints(impl()->colocation_constraints_.begin(),
+                                    impl()->colocation_constraints_.end());
     // Sort the set.
    std::sort(constraints.begin(), constraints.end());
     // Add loc:@ prefix
@@ -231,12 +348,13 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
     });
     builder->Attr(kColocationAttrName, constraints);
   }
-  if (!device_.empty()) {
-    builder->Device(device_);
+  if (!impl()->device_.empty()) {
+    builder->Device(impl()->device_);
   }
 }
 
-string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
+string Scope::Impl::GetUniqueName(const string& prefix,
+                                  bool check_single_use) const {
   if (check_single_use && single_use_scope()) {
     if (*scope_used_) {
       *status_ =
@@ -256,7 +374,7 @@ string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
   return unique_name;
 }
 
-string Scope::GetNameForOp(const string& default_name) const {
+string Scope::Impl::GetNameForOp(const string& default_name) const {
   const string unique_name =
       GetUniqueName(default_name, true /* check_single_use */);
   const string sep = name_.empty() || unique_name.empty() ? "" : "/";
@@ -264,96 +382,125 @@ string Scope::GetNameForOp(const string& default_name) const {
 }
 
 string Scope::GetUniqueNameForOp(const string& default_name) const {
-  if (single_use_scope()) {
-    if (op_name_.empty() || *scope_used_) {
-      *status_ =
+  if (impl()->single_use_scope()) {
+    if (impl()->op_name_.empty() || *impl()->scope_used_) {
+      *impl()->status_ =
           errors::InvalidArgument("Cannot get a unique name in this scope");
       return "";
     }
-    *scope_used_ = true;
-    return op_name_;
+    *impl()->scope_used_ = true;
+    return impl()->op_name_;
  }
-  return op_name_.empty() ? GetNameForOp(default_name) : GetNameForOp(op_name_);
+  return impl()->op_name_.empty() ? impl()->GetNameForOp(default_name)
+                                  : impl()->GetNameForOp(impl()->op_name_);
 }
 
 Scope Scope::NewSubScope(const string& child_scope_name) const {
   if (child_scope_name.empty()) {
-    return Scope(*this, Scope::Tags::ScopeName(), name_, true /* copy_names */);
+    return Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->name_,
+                          true /* copy_names */));
   }
   const string unique_name =
-      GetUniqueName(child_scope_name, false /* check_single_use */);
-  const string sep = name_.empty() || unique_name.empty() ? "" : "/";
-  return Scope(*this, Scope::Tags::ScopeName(),
-               strings::StrCat(name_, sep, unique_name),
-               false /* copy_names */);
+      impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
"" : "/"; + return Scope(new Impl(*this, Impl::Tags::ScopeName(), + strings::StrCat(impl()->name_, sep, unique_name), + false /* copy_names */)); } Scope Scope::WithOpName(const string& op_name) const { - if (single_use_scope()) { + if (impl()->single_use_scope()) { UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name, " on this scope")); return *this; } - return Scope(*this, Scope::Tags::OpName(), name_, op_name); + return Scope(new Impl(*this, Impl::Tags::OpName(), impl()->name_, op_name)); } Scope Scope::WithControlDependencies( const gtl::ArraySlice& control_deps) const { - return Scope(*this, Scope::Tags::ControlDeps(), + return Scope( + new Impl(*this, Impl::Tags::ControlDeps(), std::vector(control_deps.begin(), control_deps.end()), - /* clear_control_deps */ false); + /* clear_control_deps */ false)); } Scope Scope::WithControlDependencies(const Output& control_dep) const { - return Scope(*this, Scope::Tags::ControlDeps(), - std::vector(1, control_dep.op()), - /* clear_control_deps */ false); + return Scope(new Impl(*this, Impl::Tags::ControlDeps(), + std::vector(1, control_dep.op()), + /* clear_control_deps */ false)); } Scope Scope::WithNoControlDependencies() const { - return Scope(*this, Scope::Tags::ControlDeps(), std::vector(), - /* clear_control_deps */ true); + return Scope(new Impl(*this, Impl::Tags::ControlDeps(), + std::vector(), + /* clear_control_deps */ true)); } Scope Scope::WithDevice(const string& device) const { - return Scope(*this, Scope::Tags::Device(), device); + return Scope(new Impl(*this, Impl::Tags::Device(), device)); } Scope Scope::ColocateWith(const Operation& op) const { - return Scope(*this, Scope::Tags::Colocate(), op, - /* clear_colocations */ false); + return Scope(new Impl(*this, Impl::Tags::Colocate(), op, + /* clear_colocations */ false)); } Scope Scope::ClearColocation() const { - return Scope(*this, Scope::Tags::Colocate(), Operation(), - /* clear_colocations */ true); + return Scope(new Impl(*this, Impl::Tags::Colocate(), Operation(), + /* clear_colocations */ true)); } Scope Scope::ExitOnError() const { - return Scope(*this, Scope::Tags::ExitOnError()); + return Scope(new Impl(*this, Impl::Tags::ExitOnError())); } Scope Scope::WithKernelLabel(const string& kernel_label) const { - return Scope(*this, Scope::Tags::KernelLabel(), kernel_label); + return Scope(new Impl(*this, Impl::Tags::KernelLabel(), kernel_label)); } CompositeOpScopes Scope::GetCompositeOpScopes( const string& composite_op_name) const { - if (op_name_.empty() && composite_op_name.empty()) { + if (impl()->op_name_.empty() && composite_op_name.empty()) { UpdateStatus(errors::InvalidArgument( "Cannot create composite op scopes with empty name")); return {*this, *this}; } - if (!single_use_scope()) { - Scope child = NewSubScope(op_name_.empty() ? composite_op_name : op_name_); - const string child_op_sep = name_.empty() ? "" : "_"; - return {child, Scope(child, Scope::Tags::SingleUseScope(), - strings::StrCat(name_, child_op_sep, child.name_))}; + if (!impl()->single_use_scope()) { + Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name + : impl()->op_name_); + const string child_op_sep = impl()->name_.empty() ? 
"" : "_"; + const string child_name = + strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_); + return {child, + Scope(new Impl(child, Impl::Tags::SingleUseScope(), child_name))}; } else { - return { - Scope(*this, Scope::Tags::ScopeName(), op_name_, true /* copy_names */), - *this}; + return {Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->op_name_, + true /* copy_names */)), + *this}; } } +class InternalScope { + public: + // NewScope doesn't take ownership of the inputs. + static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; + for (const Node* node : graph->nodes()) { + (*name_map)[node->name()] = 0; + } + // We provide null destructors for these shared ptrs (except for name_map) + // since the caller owns them and doesn't want the scope to destroy them. + return Scope(new Scope::Impl( + std::shared_ptr(graph, [](Graph*) {}), + std::shared_ptr(status, [](Status*) {}), + std::shared_ptr(name_map), + std::shared_ptr(refiner, [](ShapeRefiner*) {}))); + } +}; + +Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + return InternalScope::NewScope(graph, status, refiner); +} + } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 47d1026bb23..ec3543772d8 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -23,16 +23,19 @@ limitations under the License. #include #include "tensorflow/cc/framework/ops.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { +class Graph; class GraphDef; class NodeBuilder; struct CompositeOpScopes; +/// @addtogroup core +/// @{ + /// A `Scope` object represents a set of related TensorFlow ops that have the /// same properties such as a common name prefix. /// @@ -91,6 +94,10 @@ struct CompositeOpScopes; /// op-constructor functions on the same `Scope` object. class Scope { public: + Scope(const Scope& other); + ~Scope(); + Scope& operator=(const Scope& other); + // The following functions are for users making graphs. They return brand new // scopes, or scopes derived from an existing scope object. @@ -161,20 +168,21 @@ class Scope { // START_SKIP_DOXYGEN /// Update the builder with properties accumulated in this scope. + // TODO(skyewm): NodeBuilder is not part of public API void UpdateBuilder(NodeBuilder* builder) const; // END_SKIP_DOXYGEN CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const; - bool ok() const { return status_->ok(); } + bool ok() const; - Graph* graph() const { return graph_.get(); } + // TODO(skyewm): Graph is not part of public API + Graph* graph() const; - ShapeRefiner* refiner() const { return refiner_.get(); } + // TODO(skyewm): Graph is not part of public API + std::shared_ptr graph_as_shared_ptr() const; - std::shared_ptr graph_as_shared_ptr() const { return graph_; } - - Status status() const { return *status_; } + Status status() const; /// If status() is Status::OK(), convert the Graph object stored in this scope /// to a GraphDef proto and return Status::OK(). Otherwise, return the error @@ -193,74 +201,15 @@ class Scope { Status ToGraph(Graph* g) const; // END_SKIP_DOXYGEN - const std::vector& control_deps() const { return control_deps_; } + const std::vector& control_deps() const; private: - // Tag types to choose the constructor to dispatch. 
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 47d1026bb23..ec3543772d8 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -23,16 +23,19 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/cc/framework/ops.h"
-#include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 
 namespace tensorflow {
 
+class Graph;
 class GraphDef;
 class NodeBuilder;
 struct CompositeOpScopes;
 
+/// @addtogroup core
+/// @{
+
 /// A `Scope` object represents a set of related TensorFlow ops that have the
 /// same properties such as a common name prefix.
 ///
@@ -91,6 +94,10 @@ struct CompositeOpScopes;
 /// op-constructor functions on the same `Scope` object.
 class Scope {
  public:
+  Scope(const Scope& other);
+  ~Scope();
+  Scope& operator=(const Scope& other);
+
   // The following functions are for users making graphs. They return brand new
   // scopes, or scopes derived from an existing scope object.
 
@@ -161,20 +168,21 @@ class Scope {
 
   // START_SKIP_DOXYGEN
   /// Update the builder with properties accumulated in this scope.
+  // TODO(skyewm): NodeBuilder is not part of public API
   void UpdateBuilder(NodeBuilder* builder) const;
   // END_SKIP_DOXYGEN
 
   CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const;
 
-  bool ok() const { return status_->ok(); }
+  bool ok() const;
 
-  Graph* graph() const { return graph_.get(); }
+  // TODO(skyewm): Graph is not part of public API
+  Graph* graph() const;
 
-  ShapeRefiner* refiner() const { return refiner_.get(); }
+  // TODO(skyewm): Graph is not part of public API
+  std::shared_ptr<Graph> graph_as_shared_ptr() const;
 
-  std::shared_ptr<Graph> graph_as_shared_ptr() const { return graph_; }
-
-  Status status() const { return *status_; }
+  Status status() const;
 
   /// If status() is Status::OK(), convert the Graph object stored in this scope
   /// to a GraphDef proto and return Status::OK(). Otherwise, return the error
@@ -193,74 +201,15 @@ class Scope {
   Status ToGraph(Graph* g) const;
   // END_SKIP_DOXYGEN
 
-  const std::vector<Operation>& control_deps() const { return control_deps_; }
+  const std::vector<Operation>& control_deps() const;
 
  private:
-  // Tag types to choose the constructor to dispatch.
-  struct Tags {
-    enum class ScopeName;
-    enum class OpName;
-    enum class ControlDeps;
-    enum class Device;
-    enum class SingleUseScope;
-    enum class ExitOnError;
-    enum class KernelLabel;
-    enum class Colocate;
-  };
-
-  // A NameMap is used to keep track of suffixes for names used in a scope. A
-  // name that has not been used so far in a scope will get no suffix. Later
-  // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
-  // can share the same NameMap. For instance, a new scope created using
-  // WithControlDependencies() would share the same NameMap with the parent.
-  typedef std::unordered_map<string, int> NameMap;
-
-  Scope(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
-  Scope(const Scope& other, Tags::ScopeName, const string& name,
-        bool copy_names);
-  Scope(const Scope& other, Tags::OpName, const string& name,
-        const string& op_name);
-  Scope(const Scope& other, Tags::ControlDeps,
-        std::vector<Operation> control_deps, bool clear_control_deps);
-  Scope(const Scope& other, Tags::Device, const string& device);
-  Scope(const Scope& other, Tags::SingleUseScope, const string& op_name);
-  Scope(const Scope& other, Tags::ExitOnError);
-  Scope(const Scope& other, Tags::KernelLabel, const string& kernel_label);
-  Scope(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
-        bool clear_colocations);
-
-  std::unordered_set<string> GetColocationConstraints(
-      const Operation& colocate_with_op) const;
-
-  // Helper functions to get unique names.
-  string GetUniqueName(const string& prefix, bool check_single_use) const;
-  string GetNameForOp(const string& default_name) const;
-
-  bool single_use_scope() const { return scope_used_ != nullptr; }
-
-  // The graph, status, and name maps are shared by all child scopes
-  // created from a single 'root' scope. A root scope is created by calling the
-  // Scope::NewRootScope function, which creates a new graph, a new status and
-  // the name maps.
-  std::shared_ptr<Graph> graph_ = nullptr;
-  std::shared_ptr<Status> status_ = nullptr;
-  std::shared_ptr<NameMap> name_map_ = nullptr;
-  std::shared_ptr<ShapeRefiner> refiner_ = nullptr;
-
-  // If scope_used_ is not nullptr, op_name_ should be empty and
-  // GetUniqueNameForOp can only be called once on this scope. More calls to
-  // GetUniqueNameForOp will cause an error status to be set on this scope.
-  std::shared_ptr<bool> scope_used_ = nullptr;
-
-  const std::vector<Operation> control_deps_;
-
-  const string name_ = "";
-  const string op_name_ = "";
-  const bool exit_on_error_ = false;
-  const string kernel_label_ = "";
-  const string device_ = "";
-  const std::unordered_set<string> colocation_constraints_;
+  friend class InternalScope;
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+  Impl* impl() { return impl_.get(); }
+  const Impl* impl() const { return impl_.get(); }
+  explicit Scope(Impl*);
 };
 
 /// A helper struct to hold the scopes that would be used by a function
@@ -273,6 +222,8 @@ struct CompositeOpScopes {
   Scope last;
 };
 
+/// @}
+
 }  // namespace tensorflow
 
 #endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
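The header rewrite above is a conventional pimpl conversion. A condensed,
self-contained sketch of the idiom (hypothetical `Foo` class, assuming a
copyable `Impl`, mirroring the copy constructor / assignment / out-of-line
destructor pattern that `Scope` now uses):

```cpp
#include <memory>

// In the header: only the public surface; Impl stays an incomplete type.
class Foo {
 public:
  Foo();
  Foo(const Foo& other);
  Foo& operator=(const Foo& other);
  ~Foo();  // must be defined where Impl is a complete type
  int value() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

// In the .cc file: the implementation details.
class Foo::Impl {
 public:
  int value = 0;
};

Foo::Foo() : impl_(new Impl) {}
Foo::Foo(const Foo& other) : impl_(new Impl(*other.impl_)) {}
Foo& Foo::operator=(const Foo& other) {
  impl_.reset(new Impl(*other.impl_));  // deep copy, as Scope does
  return *this;
}
Foo::~Foo() {}
int Foo::value() const { return impl_->value; }
```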
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+
+class ShapeRefiner;
+
+// NewInternalScope returns a new scope which doesn't take ownership of
+// graph, status, name_map, and refiner.
+// This is intended to enable the C API (which is used by other language
+// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
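A usage sketch of `NewInternalScope` (editor's illustration, not part of the
patch; the variable names are invented, and the `ShapeRefiner` construction
mirrors `Scope::NewRootScope` above):

```cpp
// The caller retains ownership of all three objects and must keep them
// alive for as long as the returned Scope (and any scopes derived from
// it) are in use.
Graph graph(OpRegistry::Global());
Status status;
ShapeRefiner refiner(graph.versions().producer(), graph.op_registry());
Scope scope = NewInternalScope(&graph, &status, &refiner);
// Ops built through `scope` record failures in `status`; destroying the
// scope does not destroy graph, status, or refiner.
```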
diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc
index b0746913a16..ca78f31db51 100644
--- a/tensorflow/cc/framework/testutil.cc
+++ b/tensorflow/cc/framework/testutil.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/cc/framework/testutil.h"
 
+#include <utility>
+
 #include "tensorflow/cc/client/client_session.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/graph/default_device.h"
@@ -30,7 +32,7 @@ void GetTensors(const Scope& scope, OutputList tensors,
 
 void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
   std::vector<Tensor> outputs;
-  GetTensors(scope, {tensor}, &outputs);
+  GetTensors(scope, {std::move(tensor)}, &outputs);
   *out = outputs[0];
 }
diff --git a/tensorflow/cc/gradients/README.md b/tensorflow/cc/gradients/README.md
new file mode 100644
index 00000000000..3253163cc73
--- /dev/null
+++ b/tensorflow/cc/gradients/README.md
@@ -0,0 +1,52 @@
+# C++ gradients
+
+Gradients are currently being ported from
+[python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/ops)
+to C++ (in this directory).
+
+Contributions are welcome and much appreciated; please follow the instructions
+below.
+
+1.  Create the op gradient function in `foo_grad.cc` corresponding to the
+    `foo_grad.py` file where the op originated (e.g., `array_grad.py` op
+    gradients should be written in `array_grad.cc`).
+
+2.  Write the op gradient with the following naming scheme:
+
+        Status OpNameGrad(const Scope& scope, const Operation& op,
+                          const std::vector<Output>& grad_inputs,
+                          std::vector<Output>* grad_outputs) {
+          ...
+          return scope.status();
+        }
+        REGISTER_GRADIENT_OP("OpName", OpNameGrad);
+
+3.  Op gradients are implemented using the [C++
+    API](https://www.tensorflow.org/api_docs/cc/).
+
+4.  Tests should be included in `foo_grad_test.cc`. Please see
+    [`array_grad_test.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/array_grad_test.cc)
+    for many examples. Tests are as simple as creating a placeholder input
+    for the op's inputs and calling `RunTest` (`RunTest` uses a [gradient
+    checker](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/framework/gradient_checker.cc)
+    to verify that the theoretical gradient matches the numeric gradient). For
+    example:
+
+        TEST_F(ArrayGradTest, IdentityGrad) {
+          TensorShape shape({5, 2});
+          auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+          auto y = Identity(scope_, x);
+          RunTest(x, shape, y, shape);
+        }
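+A complete registered gradient, mirroring the existing `Neg` gradient in
+[`math_grad.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/gradients/math_grad.cc),
+ties the steps above together:
+
+    Status NegGrad(const Scope& scope, const Operation& op,
+                   const std::vector<Output>& grad_inputs,
+                   std::vector<Output>* grad_outputs) {
+      // dy/dx = -1, so grad(x) = -grad(y).
+      grad_outputs->push_back(Neg(scope, grad_inputs[0]));
+      return scope.status();
+    }
+    REGISTER_GRADIENT_OP("Neg", NegGrad);
+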
", + message); + grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); return scope.status(); } REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad); @@ -201,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { auto seq_lengths = op.input(1); int batch_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); int seq_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); grad_outputs->push_back( ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, ReverseSequence::BatchDim(batch_dim))); @@ -253,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -276,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -299,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -309,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -319,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); @@ -332,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index aff06531395..71d9a8ed7be 100644 --- 
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index aff06531395..71d9a8ed7be 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -21,6 +21,17 @@ limitations under the License.
 namespace tensorflow {
 namespace ops {
 namespace {
 
+// Returns the conjugate of `out` if it is complex valued; otherwise returns
+// `out` unchanged.
+Output ConjugateHelper(const Scope& scope, const Output& out) {
+  DataType dtype = out.type();
+  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
+    return Conj(scope, out);
+  } else {
+    return out;
+  }
+}
+
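+// Gradient convention note (editor's commentary, restating the comments used
+// throughout this file): for y = f(x), each gradient function computes
+//   grad(x) = grad(y) * conj(dy/dx)
+// Conj is the identity for real dtypes, so ConjugateHelper only inserts a
+// Conj op for DT_COMPLEX64/DT_COMPLEX128 and this reduces to the ordinary
+// chain rule everywhere else.
+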
 // TODO(andydavis) Add control dependencies to gradient functions (as needed).
 
 Status AbsGrad(const Scope& scope, const Operation& op,
@@ -44,9 +55,11 @@ REGISTER_GRADIENT_OP("Neg", NegGrad);
 Status InvGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
-  // dx = dy * (-1 * (y * y))
+  // dy/dx = -1/x^2 = -y^2
+  auto dydx = Neg(scope, Square(scope, op.output(0)));
+  // grad(x) = grad(y) * conj(dy/dx)
   grad_outputs->push_back(
-      Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Inv", InvGrad);
@@ -55,10 +68,12 @@ REGISTER_GRADIENT_OP("Reciprocal", InvGrad);
 Status SquareGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
-  // dx = dy * (2 * x)
+  // dy/dx = (2 * x)
   auto two = Cast(scope, Const(scope, 2), op.input(0).type());
+  auto dydx = Mul(scope, two, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
   grad_outputs->push_back(
-      Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Square", SquareGrad);
@@ -68,11 +83,12 @@ Status SqrtGrad(const Scope& scope, const Operation& op,
                 std::vector<Output>* grad_outputs) {
   // y = sqrt(x)
   // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
-  // dx = dy * (0.5 * (1 / y))
   auto y_inv = Reciprocal(scope, op.output(0));
   auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
-  auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
-  grad_outputs->push_back(dx);
+  auto dydx = Mul(scope, half, y_inv);
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);
@@ -82,14 +98,14 @@ Status RsqrtGrad(const Scope& scope, const Operation& op,
                  std::vector<Output>* grad_outputs) {
   // y = 1/x^1/2 = x^-1/2
   // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
-  // dx = dy * (-1/2 * y * x^-1)
   auto x_inv = Reciprocal(scope, op.input(0));
   auto y = op.output(0);
   auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
   auto a = Mul(scope, neghalf, x_inv);
-  auto b = Mul(scope, a, y);
-  auto dx = Mul(scope, grad_inputs[0], b);
-  grad_outputs->push_back(dx);
+  auto dydx = Mul(scope, a, y);
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
@@ -97,10 +113,11 @@ REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
 Status ExpGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
-  // y = exp(x)
-  // dy/dx = exp(x)
-  // dx = dy * y
-  grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
+  // dy/dx = exp(x) = y
+  // grad(x) = grad(y) * conj(dy/dx)
+  //         = grad(y) * conj(y)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Exp", ExpGrad);
@@ -108,10 +125,12 @@ REGISTER_GRADIENT_OP("Exp", ExpGrad);
 Status Expm1Grad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
-  // f(x) = expm1(x)
-  // df/dx = exp(x)
-  // dx = dy * exp(x)
-  grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0))));
+  // y = expm1(x)
+  // dy/dx = exp(x)
+  auto dydx = Exp(scope, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Expm1", Expm1Grad);
@@ -119,11 +138,12 @@ REGISTER_GRADIENT_OP("Expm1", Expm1Grad);
 Status LogGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
-  // f(x) = log(x) = y
-  // df/dx = 1 / x
-  // dx = dy * (1 / x)
+  // y = log(x)
+  // dy/dx = 1 / x
+  auto dydx = Reciprocal(scope, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
   grad_outputs->push_back(
-      Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Log", LogGrad);
@@ -131,26 +151,54 @@ REGISTER_GRADIENT_OP("Log", LogGrad);
 Status Log1pGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
-  // f(x) = log1p(x) = y
-  // df/dx = 1 / (1 + x)
-  // dx = dy * (1 / (1 + x))
+  // y = log1p(x)
+  // dy/dx = 1 / (1 + x)
   auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
+  // grad(x) = grad(y) * conj(dy/dx)
   grad_outputs->push_back(
-      Div(scope, grad_inputs[0], Add(scope, one, op.input(0))));
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Log1p", Log1pGrad);
 
+Status SinhGrad(const Scope& scope, const Operation& op,
+                const std::vector<Output>& grad_inputs,
+                std::vector<Output>* grad_outputs) {
+  // y = sinh(x)
+  // dy/dx = cosh(x)
+  auto dydx = Cosh(scope, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Sinh", SinhGrad);
+
+Status CoshGrad(const Scope& scope, const Operation& op,
+                const std::vector<Output>& grad_inputs,
+                std::vector<Output>* grad_outputs) {
+  // y = cosh(x)
+  // dy/dx = sinh(x)
+  auto dydx = Sinh(scope, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("Cosh", CoshGrad);
+
 Status TanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
   // y = tanh(x)
   // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
-  // dx = dy * (1 - y^2)
   auto y2 = Square(scope, op.output(0));
   auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
-  auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
-  grad_outputs->push_back(dx);
+  auto dydx = Sub(scope, one, y2);
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Tanh", TanhGrad);
@@ -160,11 +208,13 @@ Status SigmoidGrad(const Scope& scope, const Operation& op,
                    std::vector<Output>* grad_outputs) {
   // y = 1 / (1 + exp(-x))
   // dy/dx = y * (1 - y)
-  // dx = dy * y * (1 - y)
   auto y = op.output(0);
   auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
-  auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
-  grad_outputs->push_back(dx);
+  auto dydx = Mul(scope, y, Sub(scope, one, y));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);
@@ -185,9 +235,10 @@ Status SinGrad(const Scope& scope, const Operation& op,
                std::vector<Output>* grad_outputs) {
   // y = sin(x)
   // dy/dx = cos(x)
-  // dx = dy * cos(x)
-  auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
-  grad_outputs->push_back(dx);
+  auto dydx = Cos(scope, op.input(0));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Sin", SinGrad);
@@ -197,9 +248,10 @@ Status CosGrad(const Scope& scope, const Operation& op,
                std::vector<Output>* grad_outputs) {
   // y = cos(x)
   // dy/dx = -sin(x)
-  // dx = dy * -sin(x)
-  auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
-  grad_outputs->push_back(dx);
+  auto dydx = Neg(scope, Sin(scope, op.input(0)));
+  // grad(x) = grad(y) * conj(dy/dx)
+  grad_outputs->push_back(
+      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Cos", CosGrad);
@@ -208,12 +260,12 @@ Status AsinGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
   // y = asin(x)
-  // dy/dx = 1 / (1 - x * x)^1/2
-  // dx = dy * (1 / (1 - x * x)^1/2)
+  // dy/dx = 1 / sqrt(1 - x^2)
   auto x2 = Square(scope, op.input(0));
   auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
   auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
-  auto dx = Mul(scope, grad_inputs[0], dydx);
+  // grad(x) = grad(y) * conj(dy/dx)
+  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
   grad_outputs->push_back(dx);
   return scope.status();
 }
@@ -239,9 +291,9 @@ Status TanGrad(const Scope& scope, const Operation& op,
                std::vector<Output>* grad_outputs) {
   // y = tan(x)
   // dy/dx = sec(x)^2 = 1 / cos(x)^2
-  // dx = dy * (1 / cos(x)^2)
   auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
-  auto dx = Mul(scope, grad_inputs[0], dydx);
+  // grad(x) = grad(y) * conj(dy/dx)
+  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
   grad_outputs->push_back(dx);
   return scope.status();
 }
@@ -324,7 +376,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
                         const string& attr_adj_x, const string& attr_adj_y,
                         std::vector<Output>* grad_outputs) {
   DataType dtype;
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
+  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
   if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
     return errors::Unimplemented(
         "MatMul gradient for complex data type is not supported yet.");
@@ -332,8 +384,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
   bool ta;
   bool tb;
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));
 
   if (!ta && !tb) {
     return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
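The `dy/dx` comments in the gradient functions above can be spot-checked
numerically; a standalone sketch (plain C++, no TensorFlow dependency,
editor's illustration) for the derivative used in `RsqrtGrad`:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // y = x^(-1/2); analytically dy/dx = -0.5 * x^(-3/2).
  const double x = 2.0;
  const double h = 1e-6;
  const double numeric =
      (1.0 / std::sqrt(x + h) - 1.0 / std::sqrt(x - h)) / (2.0 * h);
  const double analytic = -0.5 * std::pow(x, -1.5);
  std::printf("numeric=%.9f analytic=%.9f\n", numeric, analytic);
  return 0;
}
```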
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index d7278929d46..1653b04378f 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -45,6 +45,8 @@ class CWiseUnaryGradTest : public ::testing::Test {
     EXPM1,
     LOG,
     LOG1P,
+    SINH,
+    COSH,
     TANH,
     SIGMOID,
     SIGN,
@@ -56,23 +58,25 @@ class CWiseUnaryGradTest : public ::testing::Test {
     ATAN
   };
 
-  void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn,
-                     std::function<float(float)> dy_fn,
-                     std::function<float(float, float)> dx_fn) {
-    Tensor x(DT_FLOAT, {2, 3, 2});
-    auto x_flat = x.flat<float>();
+  template <typename T>
+  void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn,
+                     const std::function<T(const T&)>& dy_fn,
+                     const std::function<T(const T&, const T&)>& dx_fn) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    Tensor x(dtype, {2, 3, 2});
+    auto x_flat = x.flat<T>();
     for (int i = 0; i < x_flat.size(); ++i) {
       x_flat(i) = x_fn(i);
     }
-    Tensor dy(DT_FLOAT, {2, 3, 2});
-    auto dy_flat = dy.flat<float>();
+    Tensor dy(dtype, {2, 3, 2});
+    auto dy_flat = dy.flat<T>();
     for (int i = 0; i < dy_flat.size(); ++i) {
       dy_flat(i) = dy_fn(x_flat(i));
     }
-    Tensor dx(DT_FLOAT, {2, 3, 2});
-    auto dx_flat = dx.flat<float>();
+    Tensor dx(dtype, {2, 3, 2});
+    auto dx_flat = dx.flat<T>();
     for (int i = 0; i < dx_flat.size(); ++i) {
       dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
     }
@@ -109,6 +113,12 @@ class CWiseUnaryGradTest : public ::testing::Test {
       case LOG1P:
         y = Log1p(scope_, x);
         break;
+      case SINH:
+        y = Sinh(scope_, x);
+        break;
+      case COSH:
+        y = Cosh(scope_, x);
+        break;
       case TANH:
         y = Tanh(scope_, x);
         break;
@@ -146,7 +156,19 @@ class CWiseUnaryGradTest : public ::testing::Test {
     test::ExpectClose(output, dx);
   }
 
-  float RV(std::vector<float> v) { return v[random::New64() % v.size()]; }
+  float RV(const std::vector<float>& v) {
+    return v[random::New64() % v.size()];
+  }
+
+  complex64 CRV(const std::vector<complex64>& v) {
+    return v[random::New64() % v.size()];
+  }
+
+  complex64 conjugate(const complex64& val) {
+    return complex64(val.real(), -val.imag());
+  }
+
+  const complex64 one_{1.0, 0};
 
   Scope scope_;
 };
@@ -155,14 +177,14 @@ TEST_F(CWiseUnaryGradTest, Abs) {
   auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
   auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
   auto dx_fn = [this](const float x, const float dy) { return x * dy; };
-  TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Neg) {
   auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
   auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
   auto dx_fn = [this](const float x, const float dy) { return -dy; };
-  TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Reciprocal) {
@@ -171,14 +193,36 @@ TEST_F(CWiseUnaryGradTest, Reciprocal) {
   auto dx_fn = [this](const float x, const float dy) {
     return -(1 / (x * x)) * dy;
   };
-  TestCWiseGrad(INV, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64 x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64 x, const complex64 dy) {
+    return -conjugate(one_ / (x * x)) * dy;
+  };
+  TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Square) {
   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
   auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
   auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
-  TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Square_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return conjugate(complex64(2, 0) * x) * dy;
+  };
+  TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Sqrt) {
@@ -187,7 +231,18 @@ TEST_F(CWiseUnaryGradTest, Sqrt) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * 0.5 * (1.0 / std::sqrt(x));
   };
-  TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sqrt_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy;
+  };
+  TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Rsqrt) {
@@ -196,7 +251,18 @@ TEST_F(CWiseUnaryGradTest, Rsqrt) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
   };
-  TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy;
+  };
+  TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Exp) {
@@ -205,7 +271,18 @@ TEST_F(CWiseUnaryGradTest, Exp) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * std::exp(x);
   };
-  TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Exp_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(std::exp(x));
+  };
+  TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Expm1) {
@@ -214,14 +291,36 @@ TEST_F(CWiseUnaryGradTest, Expm1) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * std::exp(x);
   };
-  TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Expm1_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(std::exp(x));
+  };
+  TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Log) {
   auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
   auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
   auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
-  TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Log_Complex) {
+  auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(one_ / x);
+  };
+  TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Log1p) {
@@ -230,7 +329,64 @@ TEST_F(CWiseUnaryGradTest, Log1p) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * (1.0 / (1.0 + x));
   };
-  TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Log1p_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy / (one_ + conjugate(x));
+  };
+  TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sinh) {
+  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+  auto dx_fn = [this](const float x, const float dy) {
+    return dy * std::cosh(x);
+  };
+  TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sinh_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(std::cosh(x));
+  };
+  TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cosh) {
+  auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+  auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+  auto dx_fn = [this](const float x, const float dy) {
+    return dy * std::sinh(x);
+  };
+  TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cosh_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(std::sinh(x));
+  };
+  TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Tanh) {
@@ -240,7 +396,21 @@ TEST_F(CWiseUnaryGradTest, Tanh) {
     const float y = std::tanh(x);
     return dy * (1.0 - y * y);
   };
-  TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Tanh_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    const complex64 y = std::tanh(x);
+    return dy * conjugate((one_ - y * y));
+  };
+  TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Sigmoid) {
@@ -250,14 +420,28 @@ TEST_F(CWiseUnaryGradTest, Sigmoid) {
     const float y = 1.0 / (1.0 + std::exp(-x));
     return dy * y * (1.0 - y);
   };
-  TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    const complex64 y = one_ / (one_ + std::exp(-x));
+    return dy * conjugate(y * (one_ - y));
+  };
+  TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Sign) {
   auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
   auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
   auto dx_fn = [this](const float x, const float dy) { return 0.0; };
-  TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Sin) {
@@ -266,7 +450,20 @@ TEST_F(CWiseUnaryGradTest, Sin) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * std::cos(x);
   };
-  TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sin_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(std::cos(x));
+  };
+  TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Cos) {
@@ -275,7 +472,20 @@ TEST_F(CWiseUnaryGradTest, Cos) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * -1.0 * std::sin(x);
  };
-  TestCWiseGrad(COS, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cos_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy * conjugate(-std::sin(x));
+  };
+  TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn);
 }
 
 TEST_F(CWiseUnaryGradTest, Asin) {
@@ -284,7 +494,24 @@ TEST_F(CWiseUnaryGradTest, Asin) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * (1.0 / std::sqrt(1.0 - x * x));
   };
-  TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Asin_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy / conjugate(std::sqrt(one_ - x * x));
+  };
+  // TODO(kbsriram)
+  // Enable test when the asin kernel supports complex numbers
+  if (false) {
+    TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn);
+  }
 }
 
 TEST_F(CWiseUnaryGradTest, Acos) {
@@ -293,7 +520,24 @@ TEST_F(CWiseUnaryGradTest, Acos) {
   auto dx_fn = [this](const float x, const float dy) {
     return dy * (-1.0 / std::sqrt(1.0 - x * x));
   };
-  TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn);
+  TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Acos_Complex) {
+  auto x_fn = [this](const int i) {
+    return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}});
+  };
+  auto dy_fn = [this](const complex64& x) {
+    return x + CRV({{-2, 2}, {-3, 3}, {1, -4}});
+  };
+  auto dx_fn = [this](const complex64& x, const complex64& dy) {
+    return dy / -conjugate(std::sqrt(one_ - x * x));
+  };
+  // TODO(kbsriram)
+  // Add test when the acos kernel supports complex numbers
+  if (false) {
+    TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn);
+  }
 }
TEST_F(CWiseUnaryGradTest, Tan) { @@ -303,7 +547,25 @@ TEST_F(CWiseUnaryGradTest, Tan) { const float cosx = std::cos(x); return dy * (1 / (cosx * cosx)); }; - TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 cosx = std::cos(x); + return dy / conjugate(cosx * cosx); + }; + // TODO(kbsriram) + // Enable when tan kernel supports complex inputs + if (false) { + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Atan) { @@ -312,7 +574,24 @@ TEST_F(CWiseUnaryGradTest, Atan) { auto dx_fn = [this](const float x, const float dy) { return dy * (1 / (1 + x * x)); }; - TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / (one_ + x * x); + }; + // TODO(kbsriram) + // Add test when the atan kernel supports complex numbers + if (false) { + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); + } } class CWiseUnaryComplexGradTest : public ::testing::Test { diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 8976a24edc6..e8cb6cf1dd1 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -23,6 +23,9 @@ limitations under the License. 
namespace tensorflow { namespace ops { +/// @defgroup const_op Const Op +/// @{ + Output Const(const Scope& scope, const Input::Initializer& val); NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); @@ -70,6 +73,8 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v, std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope, const InputList& inp); +/// @} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 5a4770f879f..3184edeb330 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values, TensorShape shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(tensor.dtype(), dtype); test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape)); } @@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype, TensorShape expected_shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(dtype, expected_dtype); EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); } diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index 9e85e67cf5d..1dffb10c033 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -22,7 +22,7 @@ op { name: "Where" input_rename: { from: "input" to: "condition" } } op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true } # control_flow_ops -# TODO(josh11b): Hide Switch and Merge once we write and migrate users to +# TODO(joshl): Hide Switch and Merge once we write and migrate users to # a Cond() API. #op { name: "Switch" hide: true } #op { name: "Merge" hide: true } @@ -150,6 +150,12 @@ op { name: "Neg" rename_to: "Negate" alias: "Neg" } op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } } op { name: "Sub" rename_to: "Subtract" alias: "Sub" } op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "SigmoidGrad" hide: true } +op { name: "TanhGrad" hide: true } +op { name: "InvGrad" hide: true } +op { name: "ReciprocalGrad" hide: true } +op { name: "SqrtGrad" hide: true } +op { name: "RsqrtGrad" hide: true } # *Grad ops get hidden, only for use by the gradient code.
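As an aside on the `const_op.h` hunk above: `@defgroup` / `@{` / `@}` is Doxygen member-group syntax, so every declaration between the braces is documented under a single "const_op" group. A minimal sketch of the pattern (the names here are illustrative, not TensorFlow's; note the closing marker is spelled `@}`):

```c++
/// @defgroup example_ops Example Ops
/// @{

int ExampleOp(int x);         // documented as part of example_ops
int AnotherExampleOp(int x);  // likewise

/// @}
```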
op { name: "SigmoidGrad" hide: true } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 36fec7a2f2e..1cc7cf3f202 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -9,7 +9,14 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "if_ios", + "if_mobile", + "if_not_mobile", + "tf_cc_test", +) cc_library( name = "constants", @@ -28,17 +35,33 @@ cc_library( cc_library( name = "loader", + hdrs = ["loader.h"], + deps = [ + ":loader_lite", + ] + if_not_mobile([ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ]) + if_android([ + "//tensorflow/core:android_tensorflow_lib", + ]), +) + +cc_library( + name = "loader_lite", srcs = ["loader.cc"], hdrs = ["loader.h"], deps = [ ":constants", + ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", "//tensorflow/core/util/tensor_bundle:naming", - ], + # mobile not supported yet + ]), ) tf_cc_test( @@ -66,6 +89,7 @@ filegroup( name = "saved_model_half_plus_two", srcs = glob([ "testdata/half_plus_two_pbtxt/**", + "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 7f2d5609780..94a3b3cf465 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -33,6 +33,9 @@ constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; /// SavedModel legacy init op key. constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; +/// SavedModel main op key. +constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; + /// Directory in which to save the SavedModel variables. constexpr char kSavedModelVariablesDirectory[] = "variables"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 2acf9bf777a..807f5904afc 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/tensor_bundle/naming.h" @@ -36,7 +37,7 @@ auto* load_attempt_count = monitoring::Counter<2>::New( "status"); auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", - "Latency in microseconds for SavedModels that were succesfully loaded.", + "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; @@ -106,6 +107,37 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +bool HasMainOp(const MetaGraphDef& meta_graph_def) { + const auto& collection_def_map = meta_graph_def.collection_def(); + if (collection_def_map.find(kSavedModelMainOpKey) != + collection_def_map.end()) { + return true; + } + return false; +} + +Status RunMainOp(const RunOptions& run_options, const string& export_dir, + const MetaGraphDef& meta_graph_def, + const std::vector& asset_file_defs, + Session* session) { + LOG(INFO) << "Running MainOp on SavedModel bundle."; + const auto& collection_def_map = meta_graph_def.collection_def(); + const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); + if (main_op_it != collection_def_map.end()) { + if (main_op_it->second.node_list().value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one main op in : ", export_dir)); + } + std::vector> inputs; + AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); + RunMetadata run_metadata; + const StringPiece main_op_name = main_op_it->second.node_list().value(0); + return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata); + } + return Status::OK(); +} + Status RunRestore(const RunOptions& run_options, const string& export_dir, const StringPiece restore_op_name, const StringPiece variable_filename_const_op_name, @@ -121,8 +153,9 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, const string variables_index_path = io::JoinPath( variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { - return errors::NotFound( - "Checkpoint index file not found in SavedModel directory."); + LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " + "were restored."; + return Status::OK(); } const string variables_path = io::JoinPath(variables_directory, kSavedModelVariablesFilename); @@ -210,11 +243,15 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); - // TODO(sukritiramesh): Add support for a single main op to run upon load, - // which will supersede the legacy_init_op and separate RunRestore. 
- TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + if (HasMainOp(bundle->meta_graph_def)) { + TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, + bundle->meta_graph_def, asset_file_defs, + bundle->session.get())); + } else { + TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, + bundle->meta_graph_def, asset_file_defs, + bundle->session.get())); + } return Status::OK(); } diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 9b9abdbb1f4..3d634dd5154 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -36,7 +36,7 @@ struct SavedModelBundle { /// resource leaks, we explicitly call Close on Sessions that we create. ~SavedModelBundle() { if (session) { - session->Close(); + session->Close().IgnoreError(); } } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 2a8a7c5bff6..cef29e7b071 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -31,6 +31,8 @@ namespace { constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataMainOp[] = + "cc/saved_model/testdata/half_plus_two_main_op/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; @@ -165,6 +167,18 @@ TEST_F(LoaderTest, PbtxtFormat) { CheckSavedModelBundle(export_dir, bundle); } +TEST_F(LoaderTest, MainOpFormat) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataMainOp); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + TEST_F(LoaderTest, InvalidExportPath) { SavedModelBundle bundle; RunOptions run_options; diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb index eeac8b12063..4a4fd025d9d 100755 Binary files a/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb and b/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt new file mode 100644 index 00000000000..f9ff0366880 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb new file mode 100644 index 00000000000..daa272aead0 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..15b75d6ef6b Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.data-00000-of-00001 differ diff --git 
a/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index new file mode 100644 index 00000000000..7ec9fb4fe2d Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_main_op/00000123/variables/variables.index differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/00000123/saved_model.pbtxt b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/00000123/saved_model.pbtxt index 356dbe6eca6..9d7813a0a16 100755 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/00000123/saved_model.pbtxt +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/00000123/saved_model.pbtxt @@ -284,6 +284,7 @@ meta_graphs { type: "shape" default_value { shape { + unknown_rank: true } } } @@ -447,7 +448,7 @@ meta_graphs { } } tags: "serve" - tensorflow_version: "0.12.head" + tensorflow_version: "1.1.0-rc2" tensorflow_git_version: "unknown" } graph_def { @@ -885,6 +886,7 @@ meta_graphs { key: "shape" value { shape { + unknown_rank: true } } } @@ -1714,7 +1716,7 @@ meta_graphs { dtype: DT_STRING tensor_shape { } - string_val: "_temp_aeab824a1fc94305a10a2504f5995de2/part" + string_val: "_temp_80e928f1e0c844239d136d1ea966099d/part" } } } @@ -2444,7 +2446,7 @@ meta_graphs { input: "^save/restore_shard" } versions { - producer: 21 + producer: 23 } } saver_def { @@ -2503,6 +2505,42 @@ meta_graphs { } } } + signature_def { + key: "classify_x2_to_y3" + value { + inputs { + key: "inputs" + value { + name: "x2:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + outputs { + key: "scores" + value { + name: "y3:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/classify" + } + } signature_def { key: "classify_x_to_y" value { diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 53a566db950..4511d043206 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -31,8 +31,8 @@ Coordinator::Coordinator(const std::vector& clean_stop_errors) } Coordinator::~Coordinator() { - RequestStop(); - Join(); + RequestStop().IgnoreError(); + Join().IgnoreError(); } Status Coordinator::RegisterRunner(std::unique_ptr runner) { @@ -115,4 +115,15 @@ void Coordinator::WaitForStop() { } } -} // namespace +Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { + mutex_lock l(runners_lock_); + for (auto& t : runners_) { + Status s = t->ExportCostGraph(cost_graph); + if (!s.ok()) { + return s; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index dbcf0720150..0e01b19cd98 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -21,19 +21,24 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { -/// The abstract interface for runners which must implement the Join function. +/// The abstract interface for runners which must implement the Join and the +/// IsRunning function. 
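A recurring change in this diff (the `SavedModelBundle` and `Coordinator` destructors above, and `QueueRunner` below) is appending `.IgnoreError()` to `Status` results: `Status` is a must-use type, a destructor has no way to propagate failure, and the call makes the discard explicit instead of silent. A tiny self-contained sketch of the idea, with `Status` stubbed out (not TensorFlow's real class):

```c++
// Minimal stand-in for tensorflow::Status, just to make the pattern concrete.
struct Status {
  bool ok = true;
  void IgnoreError() const {}  // explicitly drop a must-use result
};

struct Coordinator {
  Status RequestStop() { return Status(); }
  Status Join() { return Status(); }
  // Destructors cannot return or throw a Status, so failures are
  // deliberately discarded rather than triggering unused-result warnings.
  ~Coordinator() {
    RequestStop().IgnoreError();
    Join().IgnoreError();
  }
};

int main() { Coordinator coord; return 0; }
```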
class RunnerInterface { public: virtual ~RunnerInterface() {} virtual Status Join() = 0; - + virtual Status ExportCostGraph(CostGraphDef* cost_graph) const { + return Status(error::INVALID_ARGUMENT, "No cost model to export."); + } /// Returns true iff the runner is running, i.e. if it is trying to populate /// its queue. virtual bool IsRunning() const = 0; @@ -101,6 +106,9 @@ class Coordinator { /// RequestStop() is called. void WaitForStop(); + // Returns the cost graph from stored run metadata in registered runners. + Status ExportCostGraph(CostGraphDef* cost_graph) const; + private: std::unordered_set clean_stop_errors_; condition_variable wait_for_stop_; @@ -111,12 +119,10 @@ class Coordinator { mutex status_lock_; Status status_ GUARDED_BY(status_lock_); - mutex runners_lock_; + mutable mutex runners_lock_; std::vector> runners_ GUARDED_BY(runners_lock_); - std::atomic num_runners_to_cancel_; - TF_DISALLOW_COPY_AND_ASSIGN(Coordinator); }; diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index 5e4a6966901..48874033841 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -29,9 +29,10 @@ namespace { using error::Code; -void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) { +void WaitForStopThread(Coordinator* coord, Notification* about_to_wait, + Notification* done) { + about_to_wait->Notify(); coord->WaitForStop(); - *stopped = true; done->Notify(); } @@ -39,22 +40,22 @@ TEST(CoordinatorTest, TestStopAndWaitOnStop) { Coordinator coord; EXPECT_EQ(coord.ShouldStop(), false); - bool stopped = false; + Notification about_to_wait; Notification done; Env::Default()->SchedClosure( - std::bind(&WaitForStopThread, &coord, &stopped, &done)); - Env::Default()->SleepForMicroseconds(10000000); - EXPECT_EQ(stopped, false); + std::bind(&WaitForStopThread, &coord, &about_to_wait, &done)); + about_to_wait.WaitForNotification(); + Env::Default()->SleepForMicroseconds(1000 * 1000); + EXPECT_FALSE(done.HasBeenNotified()); - coord.RequestStop(); + TF_EXPECT_OK(coord.RequestStop()); done.WaitForNotification(); - EXPECT_EQ(stopped, true); - EXPECT_EQ(coord.ShouldStop(), true); + EXPECT_TRUE(coord.ShouldStop()); } class MockQueueRunner : public RunnerInterface { public: - MockQueueRunner(Coordinator* coord) { + explicit MockQueueRunner(Coordinator* coord) { coord_ = coord; join_counter_ = nullptr; thread_pool_.reset(new thread::ThreadPool(Env::Default(), "test-pool", 10)); @@ -66,17 +67,19 @@ class MockQueueRunner : public RunnerInterface { join_counter_ = join_counter; } - void StartCounting(std::atomic* counter, int until) { + void StartCounting(std::atomic* counter, int until, + Notification* start = nullptr) { thread_pool_->Schedule( - std::bind(&MockQueueRunner::CountThread, this, counter, until)); + std::bind(&MockQueueRunner::CountThread, this, counter, until, start)); } - void StartSettingStatus(const Status& status, BlockingCounter* counter) { - thread_pool_->Schedule( - std::bind(&MockQueueRunner::SetStatusThread, this, status, counter)); + void StartSettingStatus(const Status& status, BlockingCounter* counter, + Notification* start) { + thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this, + status, counter, start)); } - Status Join() { + Status Join() override { if (join_counter_ != nullptr) { (*join_counter_)++; } @@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface { void Stop() { stopped_ = true; } private: - void 
CountThread(std::atomic* counter, int until) { + void CountThread(std::atomic* counter, int until, Notification* start) { + if (start != nullptr) start->WaitForNotification(); while (!coord_->ShouldStop() && counter->load() < until) { (*counter)++; - Env::Default()->SleepForMicroseconds(100000); + Env::Default()->SleepForMicroseconds(10 * 1000); } - coord_->RequestStop(); + coord_->RequestStop().IgnoreError(); } - void SetStatusThread(const Status& status, BlockingCounter* counter) { - Env::Default()->SleepForMicroseconds(100000); + void SetStatusThread(const Status& status, BlockingCounter* counter, + Notification* start) { + start->WaitForNotification(); SetStatus(status); counter->DecrementCount(); } @@ -118,19 +123,19 @@ TEST(CoordinatorTest, TestRealStop) { std::unique_ptr qr1(new MockQueueRunner(&coord)); qr1->StartCounting(&counter, 100); - coord.RegisterRunner(std::move(qr1)); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); std::unique_ptr qr2(new MockQueueRunner(&coord)); qr2->StartCounting(&counter, 100); - coord.RegisterRunner(std::move(qr2)); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); // Wait until the counting has started while (counter.load() == 0) ; - coord.RequestStop(); + TF_EXPECT_OK(coord.RequestStop()); int temp_counter = counter.load(); - Env::Default()->SleepForMicroseconds(10000000); + Env::Default()->SleepForMicroseconds(1000 * 1000); EXPECT_EQ(temp_counter, counter.load()); TF_EXPECT_OK(coord.Join()); } @@ -138,12 +143,14 @@ TEST(CoordinatorTest, TestRealStop) { TEST(CoordinatorTest, TestRequestStop) { Coordinator coord; std::atomic counter(0); + Notification start; std::unique_ptr qr; for (int i = 0; i < 10; i++) { qr.reset(new MockQueueRunner(&coord)); - qr->StartCounting(&counter, 10); - coord.RegisterRunner(std::move(qr)); + qr->StartCounting(&counter, 10, &start); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr))); } + start.Notify(); coord.WaitForStop(); EXPECT_EQ(coord.ShouldStop(), true); @@ -156,41 +163,43 @@ TEST(CoordinatorTest, TestJoin) { int join_counter = 0; std::unique_ptr qr1( new MockQueueRunner(&coord, &join_counter)); - coord.RegisterRunner(std::move(qr1)); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); std::unique_ptr qr2( new MockQueueRunner(&coord, &join_counter)); - coord.RegisterRunner(std::move(qr2)); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); - coord.RequestStop(); + TF_EXPECT_OK(coord.RequestStop()); TF_EXPECT_OK(coord.Join()); EXPECT_EQ(join_counter, 2); } TEST(CoordinatorTest, StatusReporting) { Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE}); + Notification start; BlockingCounter counter(3); std::unique_ptr qr1(new MockQueueRunner(&coord)); - qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter); - coord.RegisterRunner(std::move(qr1)); + qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1))); std::unique_ptr qr2(new MockQueueRunner(&coord)); - qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter); - coord.RegisterRunner(std::move(qr2)); + qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2))); std::unique_ptr qr3(new MockQueueRunner(&coord)); - qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter); - coord.RegisterRunner(std::move(qr3)); + qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3))); + start.Notify(); counter.Wait(); - 
coord.RequestStop(); + TF_EXPECT_OK(coord.RequestStop()); EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT); } TEST(CoordinatorTest, JoinWithoutStop) { Coordinator coord; std::unique_ptr<MockQueueRunner> qr(new MockQueueRunner(&coord)); - coord.RegisterRunner(std::move(qr)); + TF_ASSERT_OK(coord.RegisterRunner(std::move(qr))); EXPECT_EQ(coord.Join().code(), Code::FAILED_PRECONDITION); } @@ -198,7 +207,7 @@ TEST(CoordinatorTest, JoinWithoutStop) { TEST(CoordinatorTest, AllRunnersStopped) { Coordinator coord; MockQueueRunner* qr = new MockQueueRunner(&coord); - coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr)); + TF_ASSERT_OK(coord.RegisterRunner(std::unique_ptr<RunnerInterface>(qr))); EXPECT_FALSE(coord.AllRunnersStopped()); qr->Stop(); diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index cd6cc673275..5aaaa116cf0 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -49,7 +49,12 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { enqueue_op_names_.insert(enqueue_op_names_.end(), queue_runner_def.enqueue_op_name().begin(), queue_runner_def.enqueue_op_name().end()); - runs_ = enqueue_op_names_.size(); + size_t op_names_size = enqueue_op_names_.size(); + if (op_names_size > kint32max) { + return Status(error::INVALID_ARGUMENT, + "Enqueue ops to run cannot exceed kint32max"); + } + runs_ = static_cast<int>(op_names_size); if (runs_ == 0) { return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run."); } @@ -77,11 +82,17 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { QueueRunner::~QueueRunner() { // Cannot run Stop() here because the session might already be closed or // destroyed. - Join(); + Join().IgnoreError(); } Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } +Status QueueRunner::StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); + return Start(sess, 0); +} + Status QueueRunner::Start(Session* sess, int wait_for) { counter_.reset(new BlockingCounter(runs_)); for (const string& enqueue_op : enqueue_op_names_) { @@ -109,12 +120,18 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } +Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); + return Start(session, wait_for_ms); +} + void QueueRunner::Stop(Session* sess) { if (coord_ != nullptr) { coord_->WaitForStop(); } if (!cancel_op_name_.empty()) { - UpdateStatus(sess->Run({}, {}, {cancel_op_name_}, nullptr)); + UpdateStatus(RealRun(sess, cancel_op_name_, false)); } stopped_ = true; } @@ -149,7 +166,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { if (coord_ && coord_->ShouldStop()) { break; } - status = sess->Run({}, {}, {enqueue_op}, nullptr); + status = RealRun(sess, enqueue_op, true); if (first_iteration) { if (!status.ok()) { mutex_lock l(mu_); @@ -170,12 +187,14 @@ // will be run anyway in this case.
if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { if (last_run && !close_op_name_.empty()) { - UpdateStatus(sess->Run({}, {}, {close_op_name_}, nullptr)); + UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { + LOG(ERROR) << "Queue runner thread got a failure status: " + << status.ToString(); UpdateStatus(status); if (coord_) { - coord_->RequestStop(); + coord_->RequestStop().IgnoreError(); } } } @@ -185,4 +204,39 @@ Status QueueRunner::GetStatus() { return status_; } +Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { + if (!cg_mu_) { + return Status(error::FAILED_PRECONDITION, + "This QueueRunner doesn't collect a cost graph."); + } + mutex_lock l(*cg_mu_); + cost_graph->MergeFrom(*cost_graph_); + return Status::OK(); +} + +void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) { + cg_mu_.reset(new mutex()); + { + mutex_lock l(*cg_mu_); + cost_graph_.reset(new CostGraphDef()); + } + if (run_options) { + run_options_ = *run_options; + } +} + +Status QueueRunner::RealRun(Session* sess, const string& op, + bool update_costs) { + Status s; + if (update_costs && cg_mu_) { + RunMetadata metadata; + s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); + mutex_lock l(*cg_mu_); + cost_graph_->Swap(metadata.mutable_cost_graph()); + } else { + s = sess->Run({}, {}, {op}, nullptr); + } + return s; +} + } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index bfe6a305936..71ed44c9c60 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" @@ -58,9 +59,16 @@ class QueueRunner : public RunnerInterface { /// Starts the queue runner with the given session. Status Start(Session* sess); + /// Starts the queue runner with the given session and sets the run arguments + /// for sess->Run. It also collects and stores the cost model. + Status StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options = nullptr); + /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); + Status StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -74,8 +82,11 @@ class QueueRunner : public RunnerInterface { /// Returns the latest status. Status GetStatus(); + // Returns the stored cost model. + Status ExportCostGraph(CostGraphDef* cost_graph) const override; + private: - QueueRunner() : coord_(nullptr), stopped_(false) {} + QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. 
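The cost-graph plumbing above boils down to three steps: `StartAndCollectCostGraph` allocates storage under a dedicated mutex, `RealRun` swaps each run's freshly returned `CostGraphDef` into that storage, and `ExportCostGraph` merges it out for the caller (failing if collection was never enabled). A condensed standalone sketch of that locking pattern, with the protobuf stubbed out (illustrative names, not the real API):

```c++
#include <memory>
#include <mutex>
#include <string>

// Stand-in for the CostGraphDef protobuf message.
struct CostGraphDef {
  std::string nodes;
  void Swap(CostGraphDef* other) { nodes.swap(other->nodes); }
  void MergeFrom(const CostGraphDef& other) { nodes += other.nodes; }
};

class Runner {
 public:
  void EnableCostGraph() {               // cf. SetRunArgumentsAndCostGraph
    std::lock_guard<std::mutex> l(mu_);
    cost_graph_.reset(new CostGraphDef());
  }
  void RecordRun(CostGraphDef* fresh) {  // cf. RealRun with update_costs
    std::lock_guard<std::mutex> l(mu_);
    if (cost_graph_) cost_graph_->Swap(fresh);  // keep the latest snapshot
  }
  bool ExportCostGraph(CostGraphDef* out) const {
    std::lock_guard<std::mutex> l(mu_);
    if (!cost_graph_) return false;  // collection was never enabled
    out->MergeFrom(*cost_graph_);
    return true;
  }

 private:
  mutable std::mutex mu_;
  std::unique_ptr<CostGraphDef> cost_graph_;
};

int main() {
  Runner runner;
  runner.EnableCostGraph();
  CostGraphDef fresh{"node{}"};
  runner.RecordRun(&fresh);
  CostGraphDef out;
  return runner.ExportCostGraph(&out) ? 0 : 1;
}
```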
Status Init(const QueueRunnerDef& queue_runner_def); @@ -94,6 +105,10 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } + void SetRunArgumentsAndCostGraph(const RunOptions* run_options); + + Status RealRun(Session* sess, const string& op, bool update_costs); + string queue_name_; std::vector enqueue_op_names_; string close_op_name_; @@ -114,6 +129,10 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector> callbacks_; + + mutable std::unique_ptr cg_mu_; + std::unique_ptr cost_graph_ GUARDED_BY(cg_mu_); + RunOptions run_options_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 1661c5c91bb..da2fc03b6c0 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -44,6 +44,7 @@ using ops::FIFOQueue; using ops::QueueClose; using ops::QueueDequeue; using ops::QueueEnqueue; +using ops::RandomNormal; using ops::Square; using ops::Variable; @@ -84,7 +85,7 @@ QueueRunnerDef BuildQueueRunnerDef( const std::string& close_op, const std::string& cancel_op, const std::vector& queue_closed_error_codes) { QueueRunnerDef queue_runner_def; - *queue_runner_def.mutable_queue_name() = kQueueName; + *queue_runner_def.mutable_queue_name() = queue_name; for (const std::string& enqueue_op : enqueue_ops) { *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; } @@ -293,7 +294,7 @@ TEST(QueueRunnerTest, StartTimeout) { // This will timeout since queue0 is not fed and queue1 is fetching data from // queue0. EXPECT_EQ(qr->Start(session.get(), 1).code(), Code::DEADLINE_EXCEEDED); - session->Close(); + TF_EXPECT_OK(session->Close()); } TEST(QueueRunnerTest, TestCoordinatorStop) { @@ -317,8 +318,8 @@ TEST(QueueRunnerTest, TestCoordinatorStop) { TF_EXPECT_OK(QueueRunner::New(queue_runner1, &coord, &qr1)); TF_CHECK_OK(qr1->Start(session.get())); - coord.RegisterRunner(std::move(qr0)); - coord.RegisterRunner(std::move(qr1)); + TF_EXPECT_OK(coord.RegisterRunner(std::move(qr0))); + TF_EXPECT_OK(coord.RegisterRunner(std::move(qr1))); std::vector dq; TF_EXPECT_OK(session->Run({}, {kDequeueOp1}, {}, &dq)); @@ -340,9 +341,70 @@ TEST(QueueRunnerTest, CallbackCalledOnError) { bool error_caught = false; qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; }); TF_EXPECT_OK(qr->Start(session.get())); - qr->Join(); + EXPECT_FALSE(qr->Join().ok()); EXPECT_TRUE(error_caught); } +TEST(QueueRunnerTest, RunMetaDataTest) { + Scope root = Scope::NewRootScope(); + auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT}); + Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT); + Output square = Square(root.WithOpName(kSquareOpName), rnd); + auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square}); + auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); + auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, + QueueClose::CancelPendingEnqueues(true)); + auto dequeue0 = + QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT}); + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + for (auto& node : *graph_def.mutable_node()) { + node.set_device("/cpu:0"); + } + SessionOptions sess_options; + sess_options.config.mutable_graph_options()->set_build_cost_model(1); + std::unique_ptr session(NewSession(sess_options)); + + TF_CHECK_OK(session->Create(graph_def)); + + QueueRunnerDef queue_runner_def = + 
BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {}); + std::unique_ptr qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + RunOptions run_options; + TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), &run_options)); + + // Make sure there was at least one element enqueued in q0: this prevents a + // race condition where we close the queue before it was populated. + std::vector dq0; + TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); + // Second call to run dequeue op is to make sure the cost graph has been + // stored. + TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); + + CostGraphDef cost_graph; + TF_CHECK_OK(qr->ExportCostGraph(&cost_graph)); + EXPECT_TRUE(cost_graph.node_size() > 0); + + qr->Stop(session.get()); +} + +TEST(QueueRunnerTest, NoRunMetaDataTest) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); + std::unique_ptr qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + TF_CHECK_OK(qr->Start(session.get())); + + TF_EXPECT_OK(qr->Join()); + CostGraphDef cost_graph; + EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(), + error::FAILED_PRECONDITION); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index f2ecd2eddc2..49d3cca3a4e 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -227,7 +227,7 @@ int main(int argc, char* argv[]) { argv[dst++] = f; } argv[dst++] = nullptr; - argc = unknown_flags.size() + 1; + argc = static_cast(unknown_flags.size() + 1); tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::example::ConcurrentSessions(opts); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index c52a56b6428..1f6fe28188c 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -20,6 +20,7 @@ cc_library( cc_test( name = "runtime_test", + size = "small", srcs = ["runtime_test.cc"], deps = [ ":runtime", @@ -73,7 +74,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu", @@ -88,6 +89,7 @@ cc_library( cc_test( name = "codegen_test", + size = "small", srcs = ["codegen_test.cc"], data = ["codegen_test_h.golden"], deps = [ @@ -101,6 +103,7 @@ cc_test( cc_test( name = "tfcompile_util_test", + size = "small", srcs = ["tfcompile_util_test.cc"], deps = [ ":tfcompile_lib", @@ -123,9 +126,16 @@ cc_library( deps = [ ":tfcompile_lib", ":tfcompile_proto", + "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", + "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", + "//tensorflow/compiler/xla/legacy_flags:service_flags", + 
"//tensorflow/compiler/xla/legacy_flags:util_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/aot/benchmark.cc b/tensorflow/compiler/aot/benchmark.cc index 0c5e2c103ea..ff720382812 100644 --- a/tensorflow/compiler/aot/benchmark.cc +++ b/tensorflow/compiler/aot/benchmark.cc @@ -40,7 +40,7 @@ namespace benchmark { // the implementation without pulling in all of the Env dependencies. static double NowMicros() { struct timeval tv; - gettimeofday(&tv, NULL); + gettimeofday(&tv, nullptr); return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; } diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 042a72745a7..bbdb342a623 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -152,8 +152,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { str_util::ReplaceAllPairs(&code, rewrites); - str_util::ReplaceAll(&code, "{{NAME}}", name); - return code; + return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); } // Generate methods for args (inputs). @@ -366,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 46d7c03006a..01963c6df46 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -15,7 +15,7 @@ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 00c07932aca..ca17c5ab690 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -25,8 +25,9 @@ limitations under the License. 
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -199,17 +200,17 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config, for (const Fetch& fetch : config.fetch()) { missing_fetches.insert(TensorIdToString(fetch.id())); } - for (const Node* n : graph->nodes()) { + for (const Node* n : graph->op_nodes()) { if (n->type_string() == kArgOp) { string feed_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { return errors::Aborted(kArgOp, " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == kRetvalOp) { string fetch_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", fetch_id); @@ -233,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); auto insert_result = indexed_arg_nodes.insert({index, n}); if (!insert_result.second) { const Node* dup = insert_result.first->second; @@ -262,10 +263,10 @@ Status CreateXlaArgs(const Graph& graph, TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); for (const Node* node : arg_nodes) { XlaCompiler::Argument arg; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name)); + arg.kind = XlaCompiler::Argument::kParameter; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } return Status::OK(); @@ -273,11 +274,11 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, - const FunctionLibraryDefinition* flib_def, +Status ConvertGraphToXla(xla::CompileOnlyClient* client, + std::unique_ptr graph, xla::Computation* computation, bool* has_context_arg) { // Create a device and context to convert the graph into an XLA computation. - XlaOpRegistry::RegisterJitKernels(); + XlaOpRegistry::RegisterCompilationKernels(); // Populate the context with args from the graph. 
for (Node* node : graph->nodes()) { node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); @@ -288,19 +289,19 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + DeviceType device_type(DEVICE_CPU_XLA_JIT); + compiler_options.device_type = &device_type; + compiler_options.flib_def = &graph->flib_def(); + compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; XlaCompiler compiler(compiler_options); - std::unique_ptr flib_run(NewFunctionLibraryRuntime( - compiler.device_mgr(), Env::Default(), compiler.device(), - graph->versions().producer(), flib_def, OptimizerOptions())); XlaCompiler::CompilationResult result; - TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph), - flib_run.get(), xla_args, - false /* use_tuple_arg */, &result)); + TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "tfcompile", std::move(graph), + xla_args, &result)); *has_context_arg = result.requires_runtime_context; - *computation = std::move(result.computation); + *computation = std::move(*result.computation); int num_const_results = 0; for (int i = 0; i < result.outputs.size(); ++i) { @@ -334,7 +335,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, } // Compiles the XLA computation into executable code. -Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, +Status CompileXla(xla::CompileOnlyClient* client, + const xla::Computation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. 
@@ -348,10 +350,11 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, compile_result->program_shape = *pshape_or.ValueOrDie(); xla::ProgramShape* pshape = &compile_result->program_shape; std::vector arg_layouts; + arg_layouts.reserve(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::LocalClient::AheadOfTimeComputationInstance instance; + xla::CompileOnlyClient::AotComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -366,17 +369,17 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, std::move(aot_or.ValueOrDie().back())); compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = - xla::LocalClient::PointerSizeForTriple(aot_opts.triple()); + xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple()); return Status::OK(); } } // namespace Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, const FunctionLibraryDefinition* flib, - std::unique_ptr* graph) { + const MainFlags& flags, std::unique_ptr* graph) { TF_RETURN_IF_ERROR(ValidateConfig(config)); - std::unique_ptr g(new Graph(flib)); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); + std::unique_ptr g(new Graph(flib_def)); GraphDef copy_def(graph_def); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(), 0 /*node_offset*/)); @@ -388,7 +391,6 @@ Status InitGraph(const GraphDef& graph_def, const Config& config, } Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, - const FunctionLibraryDefinition* flib, CompileResult* compile_result) { // Converts the graph into an XLA computation, and compiles the // computation. @@ -396,11 +398,11 @@ Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, namespace gpu = perftools::gputools; gpu::Platform* cpu_platform = gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); - xla::LocalClient* client = - xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie(); + xla::CompileOnlyClient* client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) + .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib, - &computation, + TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation, &compile_result->has_context_arg)); if (!flags.debug_dir.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 8e9c64820ba..e929272b2e4 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -56,8 +56,7 @@ extern const char* const kDebugNameAttr; // compute the outputs. If dump_graphs is true, graph rewrites will be dumped // for debugging. Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, const FunctionLibraryDefinition* flib, - std::unique_ptr* graph); + const MainFlags& flags, std::unique_ptr* graph); // CompileResult describes the output of CompileGraph, where the object file // data and meta-information is available in aot. @@ -83,7 +82,6 @@ struct CompileResult { // // The XLA compilation options are specified in the flags. 
Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, - const FunctionLibraryDefinition* flib, CompileResult* result); } // namespace tfcompile diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 208de5498db..57727766661 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,6 +31,8 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); +#elif defined(COMPILER_MSVC) + return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; // posix_memalign requires that the requested alignment be at least @@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { free(aligned_memory); } +inline void aligned_free(void* aligned_memory) { +#if defined(COMPILER_MSVC) + _aligned_free(aligned_memory); +#else + free(aligned_memory); +#endif +} size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index ecb071a416c..6bfdf37caad 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -43,14 +43,16 @@ genrule( testonly = 1, outs = [ "test_graph_tfadd.pb", - "test_graph_tfadd_with_ckpt.pb", "test_graph_tfadd_with_ckpt.ckpt", - "test_graph_tfadd_with_ckpt_saver.pb", + "test_graph_tfadd_with_ckpt.pb", "test_graph_tfadd_with_ckpt_saver.ckpt", + "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", + "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", + "test_graph_tfsplits.pb", ], cmd = "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], @@ -114,6 +116,24 @@ tf_library( tags = ["manual"], ) +tf_library( + name = "test_graph_tffunction", + testonly = 1, + config = "test_graph_tffunction.config.pbtxt", + cpp_class = "FunctionComp", + graph = "test_graph_tffunction.pb", + tags = ["manual"], +) + +tf_library( + name = "test_graph_tfsplits", + testonly = 1, + config = "test_graph_tfsplits.config.pbtxt", + cpp_class = "SplitsComp", + graph = "test_graph_tfsplits.pb", + tags = ["manual"], +) + cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -122,9 +142,11 @@ cc_test( ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", + ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", + ":test_graph_tfsplits", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 2a2d13dc498..a898eab1d1a 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -71,7 +72,7 @@ def tfadd_with_ckpt_saver(out_dir): saver.save(sess, 
ckpt_file) # Without the SaverDef, the restore op won't be named correctly. saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir - with open(saver_file, 'w') as f: + with open(saver_file, 'wb') as f: f.write(saver.as_saver_def().SerializeToString()) @@ -95,13 +96,41 @@ def tfmatmulandadd(_): math_ops.add(x, y, name='x_y_sum') +def tffunction(_): + + @function.Defun(dtypes.int32, dtypes.int32) + def test_func(a, b): + return a + b + + x = constant_op.constant([1], name='x_const') + y = constant_op.constant([2], name='y_const') + test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg + + +def tfsplits(_): + """A more complex graph, including splits.""" + x = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='y') + for _ in range(3): + x0, x1 = array_ops.split(x, 2, 0) + y0, y1 = array_ops.split(y, 2, 0) + x0 += 1 + y0 += 1 + z = math_ops.matmul(x, y, name='x_y_prod') + a = array_ops.concat([x0, y1], axis=0, name='concat_x0_y1') + b = array_ops.concat([y0, x1], axis=0, name='concat_y0_x1') + x = math_ops.matmul(a, b, name='a_b') + y = math_ops.add(x, z) + array_ops.identity(y, name='result') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() with g.as_default(): build_graph(out_dir) filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__) - with open(filename, 'w') as f: + with open(filename, 'wb') as f: f.write(g.as_graph_def().SerializeToString()) @@ -112,6 +141,8 @@ def main(_): write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) + write_graph(tfsplits, FLAGS.out_dir) if __name__ == '__main__': @@ -121,7 +152,6 @@ if __name__ == '__main__': '--out_dir', type=str, default='', - help='Output directory for graphs, checkpoints and savers.' - ) + help='Output directory for graphs, checkpoints and savers.') FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt new file mode 100644 index 00000000000..eb9c1cacb7f --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "func_call" } +} diff --git a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt new file mode 100644 index 00000000000..85fc7da4428 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt @@ -0,0 +1,18 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x" } + shape { + dim { size: 2 } + dim { size: 2 } + } +} +feed { + id { node_name: "y" } + shape { + dim { size: 2 } + dim { size: 2 } + } +} +fetch { + id { node_name: "result" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f57d2859dfa..07562e59c8d 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -20,9 +20,11 @@ limitations under the License. 
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" +#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -376,6 +378,49 @@ TEST(TFCompileTest, MatMulAndAdd1) { } } +TEST(TFCompileTest, Function) { + // The function is equivalent to an addition + FunctionComp add_fn; + EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); + EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); + + add_fn.arg0() = 1; + add_fn.arg1() = 2; + EXPECT_TRUE(add_fn.Run()); + EXPECT_EQ(add_fn.error_msg(), ""); + EXPECT_EQ(add_fn.result0(), 3); + EXPECT_EQ(add_fn.result0_data()[0], 3); + EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); +} + +TEST(TFCompileTest, Splits) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + SplitsComp fn; + + fn.set_thread_pool(&device); + // x = [[1, 2], [3, 4]] + fn.arg0(0, 0) = 1; + fn.arg0(0, 1) = 2; + fn.arg0(1, 0) = 3; + fn.arg0(1, 1) = 4; + + // y = [[10, 20], [30, 40]] + fn.arg1(0, 0) = 10; + fn.arg1(0, 1) = 20; + fn.arg1(1, 0) = 30; + fn.arg1(1, 1) = 40; + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + const float expected[] = {7.86375557e+10, 1.34274679e+11, 1.92741717e+12, + 3.29964742e+12}; + EXPECT_NEAR(expected[0], fn.result0(0, 0), 1e4); + EXPECT_NEAR(expected[1], fn.result0(0, 1), 1e4); + EXPECT_NEAR(expected[2], fn.result0(1, 0), 1e4); + EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 6f2e0958fd3..4be4e0fbb39 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -279,7 +279,11 @@ def target_llvm_triple(): # TODO(toddw): Add target_triple for other targets. For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ + "//tensorflow:android_armeabi": "armv5-none-android", "//tensorflow:android_arm": "armv7-none-android", "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 85ef9560bbf..6fed46b4329 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -23,9 +23,16 @@ limitations under the License. 
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile.pb.h" #include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" #include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/util_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -52,7 +59,8 @@ const char kUsageHeader[] = "header file that gives access to the functionality in the object file.\n" "A typical invocation looks like this:\n" "\n" - " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n" + " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt " + "--cpp_class=\"mynamespace::MyComputation\"\n" "\n"; Status ReadProtoFile(const string& kind, const string& fname, @@ -73,6 +81,9 @@ void ParseTensorId(const string& name, TensorId* id) { Status Main(const MainFlags& flags) { // Process config. Config config; + if (flags.config.empty()) { + return errors::InvalidArgument("Must specify --config"); + } TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config)); TF_RETURN_IF_ERROR(ValidateConfig(config)); if (flags.dump_fetch_nodes) { @@ -85,15 +96,16 @@ Status Main(const MainFlags& flags) { } // Read and initialize the graph. + if (flags.graph.empty()) { + return errors::InvalidArgument("Must specify --graph"); + } GraphDef graph_def; TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def)); std::unique_ptr graph; - FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library()); - TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph)); + TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph)); CompileResult compile_result; - TF_RETURN_IF_ERROR( - CompileGraph(std::move(graph), flags, &flib, &compile_result)); + TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result)); // Write output files. 
Env* env = Env::Default(); @@ -101,6 +113,9 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + if (flags.cpp_class.empty()) { + return errors::InvalidArgument("Must specify --cpp_class"); + } TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, &header_opts.namespaces)); string header; @@ -121,9 +136,16 @@ int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; AppendMainFlags(&flag_list, &flags); + xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list); + xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list); xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); + xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::legacy_flags::AppendLlvmUtilFlags(&flag_list); + xla::legacy_flags::AppendServiceFlags(&flag_list); + xla::legacy_flags::AppendUtilFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); @@ -131,12 +153,16 @@ int main(int argc, char** argv) { QCHECK(parsed_flags_ok) << "\n" << usage; tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(argc == 1 && !flags.config.empty() && - (flags.dump_fetch_nodes || - (!flags.graph.empty() && !flags.entry_point.empty()))) - << "\n" - << usage; - - TF_QCHECK_OK(tensorflow::tfcompile::Main(flags)); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " "other than flags\n\n" << usage; + tensorflow::Status status = tensorflow::tfcompile::Main(flags); + if (status.code() == tensorflow::error::INVALID_ARGUMENT) { + std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" << usage; + return 1; + } else { + TF_QCHECK_OK(status); + } return 0; } diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc index 108ab1eab7b..c321d3ff4c7 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -24,7 +24,7 @@ namespace tensorflow { namespace tfcompile { namespace { -void ExpectErrorContains(Status status, StringPiece str) { +void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) << "expected error: " << status.error_message() << " to contain: " << str; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 414e152cee4..5f857191da7 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -18,7 +18,23 @@ package( default_visibility = [":internal"], ) +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +# This target can be used by XLA device plugins to prevent circular +# dependencies, and provides access to all of the required headers +# for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) # Target that bundles up the XLA CPU and GPU JIT devices.
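A small style point from the `tfcompile_util_test.cc` hunk above: `ExpectErrorContains` now takes `const Status&` instead of `Status` by value, since a `Status` carries a message string and the callee only reads it. The convention in isolation (illustrative function, not from this patch):

```c++
#include "tensorflow/core/lib/core/status.h"

// In-parameters of class type are passed by const reference; the callee never
// mutates or stores them, and no string copy is made.
bool IsInvalidArgument(const tensorflow::Status& status) {
  return status.code() == tensorflow::error::INVALID_ARGUMENT;
}
```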
cc_library( @@ -29,6 +45,7 @@ cc_library( ":xla_cpu_jit", ":xla_gpu_device", ":xla_gpu_jit", + "//tensorflow/compiler/plugin", ], alwayslink = 1, ) @@ -38,7 +55,7 @@ cc_library( visibility = [":friends"], deps = [ ":jit_compilation_passes", - ":xla_local_launch_op", + "//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -48,12 +65,12 @@ cc_library( cc_library( name = "xla_gpu_jit", visibility = [":friends"], - deps = [ + deps = if_cuda([ ":jit_compilation_passes", - ":xla_local_launch_op", + "//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", - ], + ]), alwayslink = 1, ) @@ -64,8 +81,10 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", + "//tensorflow/compiler/jit/kernels:xla_device_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", ], @@ -79,8 +98,10 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", + "//tensorflow/compiler/jit/kernels:xla_device_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", ], @@ -105,26 +126,22 @@ cc_library( srcs = [ "xla_device.cc", "xla_device_context.cc", - "xla_device_launch_op.cc", "xla_device_ops.cc", ], hdrs = [ "xla_device.h", "xla_device_context.h", - "xla_device_launch_op.h", "xla_device_ops.h", ], deps = [ ":common", ":jit_compilation_passes", - ":xla_compilation_cache", - ":xla_local_launch_op", + "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -132,9 +149,9 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", - "//tensorflow/core/kernels:assign_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:identity_op", @@ -142,7 +159,6 @@ cc_library( "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:variable_ops", ], - alwayslink = 1, ) cc_library( @@ -155,13 +171,13 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:variable_ops", ], ) @@ -175,27 +191,41 @@ cc_library( alwayslink = 1, ) +cc_library( + name = 
"graph_to_functiondef", + srcs = ["graph_to_functiondef.cc"], + hdrs = ["graph_to_functiondef.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", "encapsulate_subgraphs_pass.cc", - "graph_to_functiondef.cc", "mark_for_compilation_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", "encapsulate_subgraphs_pass.h", - "graph_to_functiondef.h", "mark_for_compilation_pass.h", ], deps = [ ":common", - ":parallel_check_op", - ":xla_local_launch_op", + ":graph_to_functiondef", + ":union_find", "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/jit/kernels:parallel_check_op", + "//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/compiler/jit/ops:parallel_check_op", + "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -208,6 +238,11 @@ cc_library( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + cc_test( name = "compilation_passes_test", size = "small", @@ -217,8 +252,9 @@ cc_test( "mark_for_compilation_pass_test.cc", ], deps = [ + ":common", ":compilation_passes", - ":xla_local_launch_op", + ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -226,48 +262,14 @@ cc_test( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], ) -cc_library( - name = "xla_local_launch_op", - srcs = ["xla_local_launch_op.cc"], - hdrs = ["xla_local_launch_op.h"], - deps = [ - ":common", - ":xla_compilation_cache", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core:tensorflow_opensource", - ], - alwayslink = 1, -) - -tf_kernel_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - visibility = [":friends"], - deps = [ - "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], - alwayslink = 1, -) - # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 8fde1974005..ef56ccf8e8f 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -16,14 +16,13 @@ limitations under the License. 
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/kernels/xla_local_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/jit/xla_local_launch_op.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -32,7 +31,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -40,14 +38,16 @@ namespace tensorflow { static Status BuildLaunchNode( const string& nodename, const string& function_name, const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes, - const DataTypeVector& result_dtypes, Graph* graph, Node** node) { + const DataTypeVector& constant_dtypes, int num_resources, + const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, + Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); def.set_op("_XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); + AddNodeAttr("Nresources", num_resources, &def); AddNodeAttr("Tresults", result_dtypes, &def); NameAttrList function; function.set_name(function_name); @@ -62,25 +62,32 @@ static Status BuildLaunchNode( static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; - int num_constant_args; + int num_constant_args, num_resource_args; TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args)); + GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); - if (num_constant_args < 0 || num_constant_args > node->input_types().size()) { + if (num_constant_args < 0 || num_resource_args < 0 || + num_constant_args + num_resource_args > node->num_inputs()) { return errors::InvalidArgument( - "Invalid number of constant arguments to XLA kernel"); + "Invalid number of constant/resource arguments to XLA kernel."); } + const int num_nonconst_args = + node->num_inputs() - num_constant_args - num_resource_args; + DataTypeVector const_dtypes(node->input_types().begin(), node->input_types().begin() + num_constant_args); - DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args, - node->input_types().end()); + DataTypeVector arg_dtypes( + node->input_types().begin() + num_constant_args, + node->input_types().begin() + num_constant_args + num_nonconst_args); // Build a _XlaLaunch operator to execute the function body. 
Node* launch_node; - TF_RETURN_IF_ERROR( - BuildLaunchNode(graph->NewName(node->name()), node->type_string(), - node->def().attr(), node->def().device(), const_dtypes, - arg_dtypes, node->output_types(), graph, &launch_node)); + TF_RETURN_IF_ERROR(BuildLaunchNode( + graph->NewName(node->name()), node->type_string(), node->def().attr(), + node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, + node->output_types(), graph, &launch_node)); launch_node->set_assigned_device_name(node->assigned_device_name()); // Copy incoming edges to the launch node. @@ -116,9 +123,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); - for (Node* n : graph->nodes()) { + for (Node* n : graph->op_nodes()) { // In all cases, only try to compile computational nodes. - if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { continue; } @@ -128,6 +135,11 @@ Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); } } + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, + options.flib_def); + } return Status::OK(); } @@ -151,7 +163,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); } // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterJitKernels(); + XlaOpRegistry::RegisterCompilationKernels(); if (!IsCompilable(flr, ndef)) { // ndef is calling a function that XLA can't compile. return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); @@ -159,7 +171,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, FunctionLibraryRuntime::Handle handle; // If ndef is not instantiable, e.g., the function does not exist, // simply bail out. - TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle)); + TF_RETURN_IF_ERROR( + flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK(fbody); // Can't be nullptr since we just instantiated it. std::vector<bool> const_args(fbody->arg_types.size()); @@ -179,6 +192,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, launch_def.set_op("_XlaLaunch"); launch_def.set_device(flr->device()->name()); AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); + AddNodeAttr("Nresources", 0, &launch_def); AddNodeAttr("Targs", fbody->arg_types, &launch_def); AddNodeAttr("Tresults", fbody->ret_types, &launch_def); NameAttrList func; diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index b20ad53ef64..f847d66f3c6 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -18,5 +18,6 @@ limitations under the License.
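With `Nresources` added alongside `Tconstants` and `Targs`, an `_XlaLaunch` node's inputs are laid out as `[compile-time constants | ordinary args | resource handles]`, and `ReplaceNodeWithXlaLaunch` recomputes the middle segment instead of assuming everything after the constants is an ordinary argument. The arithmetic, with made-up counts:

```c++
// Worked example (hypothetical counts) of the input segmentation:
//   slots 0..1 -> Tconstants, slots 2..4 -> Targs, slot 5 -> resource handle,
// mirroring the Tconstants/Targs/Nresources attrs on the _XlaLaunch NodeDef.
constexpr int kNumInputs = 6;
constexpr int kNumConstantArgs = 2;  // compile-time constants, first
constexpr int kNumResourceArgs = 1;  // DT_RESOURCE handles, last
constexpr int kNumNonconstArgs =
    kNumInputs - kNumConstantArgs - kNumResourceArgs;
static_assert(kNumNonconstArgs == 3, "three ordinary args in the middle");
```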
namespace tensorflow { const char* const kXlaCompileAttr = "_XlaCompile"; +const char* const kXlaScopeAttr = "_XlaScope"; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index ddc830cb770..a3aabc949db 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -23,6 +23,7 @@ namespace tensorflow { // Name of attribute used to tag operators for compilation with XLA extern const char* const kXlaCompileAttr; // "_XlaCompile" +extern const char* const kXlaScopeAttr; // "_XlaScope" } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index c1e61462085..14d8f2ab351 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -46,6 +45,7 @@ namespace tensorflow { const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; +const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { @@ -87,9 +87,12 @@ class Encapsulator { // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. + // If 'reuse_existing_functions' is set, use an existing function with the + // same name, if any. // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before // function conversion. Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library); // Write a copy of the input graph to 'graph_out', where the subgraphs are @@ -109,8 +112,8 @@ class Encapsulator { // returned by _Retval nodes. std::unique_ptr<Graph> graph; - // Which device are these nodes on? Used both to check that all nodes - // are assigned to the same device, and to assign a device to the call node. + // Which device are these nodes on? Used to assign a device to the call + // node. string device; // NodeDef for the function call node. @@ -161,7 +164,7 @@ static const char* const kRetValOp = "_Retval"; // none. string Encapsulator::GetFunctionNameAttr(Node const* node) const { string attr; - if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) { + if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { attr.clear(); } return attr; } @@ -174,8 +177,7 @@ Status Encapsulator::SplitIntoSubgraphs() { std::unordered_map<const Node*, Node*> node_images; // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. - for (Node* node : graph_in_->nodes()) { - if (node->IsSource() || node->IsSink()) continue; + for (Node* node : graph_in_->op_nodes()) { string func_id = GetFunctionNameAttr(node); if (func_id.empty()) continue; @@ -189,16 +191,10 @@ Status Encapsulator::SplitIntoSubgraphs() { image->ClearAttr(group_attribute_); node_images[node] = image; - // Check the device matches any existing device. - string device = node->assigned_device_name().empty() - ?
node->def().device() - : node->assigned_device_name(); - if (subgraph.device.empty()) { - subgraph.device = device; - } else if (subgraph.device != device) { - s.Update(errors::InvalidArgument( - "Mismatched devices for nodes to be grouped by Encapsulator")); + subgraph.device = node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name(); } } @@ -235,9 +231,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // Create a new _Retval node DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(src_subgraph.graph->NewName("output")); + ret_def.set_name(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_retval")); AddNodeAttr("T", dtype, &ret_def); AddNodeAttr("index", ret_index, &ret_def); Node* ret = src_subgraph.graph->AddNode(ret_def, &s); @@ -262,8 +265,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // This is the first time we have seen this tensor. Create an _Arg node. DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef arg_def; - NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp); + NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_arg"), + kArgOp); builder.Attr("T", dtype); builder.Attr("index", arg_index); s = builder.Finalize(&arg_def); @@ -290,11 +301,11 @@ Status Encapsulator::SplitIntoSubgraphs() { } Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // For each subgraph, build a FunctionDef. for (auto& subgraph_entry : subgraphs_) { - const string& name = subgraph_entry.first; + string name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; subgraph.call_node_def.set_op(name); @@ -331,6 +342,8 @@ Status Encapsulator::BuildFunctionDefs( for (auto& result : subgraph.results) { result.second = output_permutation[result.second]; } + + name = subgraph.call_node_def.op(); } FunctionDef fdef; @@ -345,7 +358,9 @@ Status Encapsulator::BuildFunctionDefs( strings::StrCat("encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -422,8 +437,7 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, std::unordered_map<const Node*, Node*> node_images; // Copy all unmarked nodes to the output graph.
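The `_Retval`/`_Arg` nodes created above are now named after the tensor they carry rather than via `NewName("output")`/`NewName("input")`, which makes the generated `FunctionDef`s deterministic; the updated test expectations later in this patch (`a_0_arg`, `b_0_arg`, `c_0_retval`, ...) depend on exactly this. The scheme as a tiny sketch:

```c++
#include <string>

#include "tensorflow/core/lib/strings/strcat.h"

// For an edge leaving node "b" at output slot 0:
//   feed side  -> "b_0_arg"    (kArgOp)
//   fetch side -> "b_0_retval" (kRetValOp)
std::string ArgName(const std::string& src_name, int src_output) {
  return tensorflow::strings::StrCat(src_name, "_", src_output, "_arg");
}
```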
- for (Node* node : graph_in_->nodes()) { - if (node->IsSource() || node->IsSink()) continue; + for (Node* node : graph_in_->op_nodes()) { string func_id = GetFunctionNameAttr(node); // Don't copy nodes that going to be encapsulated, unless parallel checking @@ -544,14 +558,16 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Status EncapsulateSubgraphsInFunctions( string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), &graph_in); s = encapsulator.SplitIntoSubgraphs(); if (!s.ok()) return s; - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library); + s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, + reuse_existing_functions, library); if (!s.ok()) return s; std::unique_ptr<Graph> out(new Graph(library)); @@ -563,14 +579,29 @@ Status EncapsulateSubgraphsInFunctions( return s; } +// Finds the types of the _Arg nodes, indexed by position. +static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { + for (Node* n : graph.op_nodes()) { + if (n->type_string() == kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + if (index < 0 || index >= types->size()) { + return errors::InvalidArgument("Invalid argument number"); + } + (*types)[index] = n->output_type(0); + } + } + return Status::OK(); +} + // Renumber the indices of _Arg nodes in a graph, according to // 'permutation' that maps old indices to new indices. static Status RenumberArguments(Graph* graph, const std::vector<int>& permutation) { - for (Node* n : graph->nodes()) { + for (Node* n : graph->op_nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= permutation.size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -604,19 +635,40 @@ Status EncapsulateSubgraphsPass::Run( // Optimize the subgraph. OptimizeGraph(flr.get(), subgraph); - std::vector<bool> const_args(input_permutation->size()); + const int num_args = input_permutation->size(); + std::vector<bool> const_args(num_args); TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + DataTypeVector arg_types(num_args); + TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); + // Compute a permutation of the arguments such that the constant arguments // are first.
const int num_consts = std::count(const_args.begin(), const_args.end(), true); + + const int num_resources = + std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); + const int num_nonconsts = num_args - num_resources - num_consts; + if (num_nonconsts < 0) { + return errors::Internal("num_nonconsts should be >= 0, was ", + num_nonconsts); + } + int const_pos = 0; int arg_pos = num_consts; - for (int i = 0; i < const_args.size(); ++i) { + int resource_pos = num_consts + num_nonconsts; + for (int i = 0; i < num_args; ++i) { if (const_args[i]) { + if (arg_types[i] == DT_RESOURCE) { + return errors::Internal( + "Resource arguments cannot be constant (argument ", i, ")"); + } (*input_permutation)[i] = const_pos; ++const_pos; + } else if (arg_types[i] == DT_RESOURCE) { + (*input_permutation)[i] = resource_pos; + ++resource_pos; } else { (*input_permutation)[i] = arg_pos; ++arg_pos; @@ -631,12 +683,14 @@ Status EncapsulateSubgraphsPass::Run( AddNodeAttr(kXlaCompiledKernelAttr, true, node); AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); return Status::OK(); }; TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, &graph_out, library)); + flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, + &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, @@ -650,7 +704,7 @@ Status EncapsulateSubgraphsPass::Run( bool IsXlaCompiledKernel(const Node& node) { bool is_compiled = false; bool has_compilation_attr = - GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() && + GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && is_compiled; return has_compilation_attr ? is_compiled : false; } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index ffd39f0b77f..b0987f76c91 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -34,6 +34,8 @@ namespace tensorflow { // 'input_permutation' and 'output_permutation' are initialized to the identity // permutation. 'nodedef' is the NodeDef for the call to the function under // construction, provided to allow additional attributes to be set. +// The rewrite may also change the NodeDef's operator name, and that +// name will be used as the name of the generated function. typedef std::function<Status( std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation, std::vector<int>* output_permutation, NodeDef* node_def)> @@ -53,6 +55,9 @@ typedef std::function<Status( std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. @@ -70,12 +76,22 @@ extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? bool IsXlaCompiledKernel(const Node& node); -// Functions produce by the EncapsulateSubgraphs pass have their arguments -// ordered such that compile-time constant arguments are first in the argument -// order. The functions are annotated with the following attribute giving the -// number of constant arguments.
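The permutation assembled above sorts a function's arguments into three runs: compile-time constants first, ordinary arguments next, `DT_RESOURCE` arguments last. A self-contained worked example with illustrative values:

```c++
#include <cassert>
#include <vector>

int main() {
  // arg 0: data, arg 1: compile-time const, arg 2: resource, arg 3: data.
  std::vector<bool> const_args = {false, true, false, false};
  std::vector<bool> resource_args = {false, false, true, false};
  const int num_args = 4, num_consts = 1, num_resources = 1;
  const int num_nonconsts = num_args - num_consts - num_resources;  // == 2

  std::vector<int> input_permutation(num_args);
  int const_pos = 0, arg_pos = num_consts;
  int resource_pos = num_consts + num_nonconsts;
  for (int i = 0; i < num_args; ++i) {
    if (const_args[i]) input_permutation[i] = const_pos++;
    else if (resource_args[i]) input_permutation[i] = resource_pos++;
    else input_permutation[i] = arg_pos++;
  }
  // Old positions {0,1,2,3} land at new positions {1,0,3,2}.
  assert((input_permutation == std::vector<int>{1, 0, 3, 2}));
}
```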
+// Functions produced by the EncapsulateSubgraphs pass have their arguments in +// the order: +// 1) compile-time constant arguments, in host memory, +// 2) other arguments, in device memory. +// 3) resource variable arguments, in host memory. Note that only the resource +// Tensor itself is in host memory; the underlying value may be in device +// memory. +// The functions are annotated with the following attributes that describe how +// many constant and resource arguments there are: + +// Name of the attribute containing the number of constant arguments. extern const char* const kXlaNumConstantArgsAttr; +// Name of the attribute containing the number of resource variable arguments. +extern const char* const kXlaNumResourceArgsAttr; + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c85882e0d7f..4a1dbaf05dc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <utility> + #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/graph/equal_graph_def.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace { @@ -76,7 +78,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, #define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \ do { \ string diff; \ - EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \ + EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \ << diff << "\nActual: " << actual.DebugString(); \ } while (false) @@ -101,15 +103,15 @@ Node* Input(const GraphDefBuilder::Options& opts) { } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { - return ops::UnaryOp("UnaryTest", a, opts); + return ops::UnaryOp("UnaryTest", std::move(a), opts); } Node* Binary(ops::NodeOut a, ops::NodeOut b, const GraphDefBuilder::Options& opts) { - return ops::BinaryOp("BinaryTest", a, b, opts); + return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } -Node* AddNLike(std::vector<ops::NodeOut> inputs, +Node* AddNLike(const std::vector<ops::NodeOut>& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest", @@ -127,7 +129,7 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", opts.op_registry()); - node_builder.Input(a).Attr("index", index); + node_builder.Input(std::move(a)).Attr("index", index); return opts.FinalizeBuilder(&node_builder); } @@ -144,8 +146,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr<Graph> graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, - /*
rewrite_subgraph_fn= */ {}, - /* parallel_checking= */ false, + /*rewrite_subgraph_fn=*/{}, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -168,7 +171,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) { GraphDef graphdef_in; FunctionDefLibrary library_in; - builder.ToGraphDef(&graphdef_in); + TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in)); *library_in.add_function() = test::function::XTimesTwo(); GraphDef graphdef_out = graphdef_in; @@ -195,7 +198,7 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); - b1.ToGraphDef(&graphdef); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } TF_EXPECT_OK(Encapsulate(&graphdef, &library)); @@ -205,12 +208,12 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, - {{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, }, - {{"output__2", "c:o:0"}}); + {{"c_0_retval", "c:o:0"}}); { std::unique_ptr<FunctionLibraryDefinition> lib_def( @@ -224,7 +227,7 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Binary(a, call, b2.opts().WithName("E")); - b2.ToGraphDef(&graphdef_expected); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } // If there are no marked nodes, funcification should be a no-op. @@ -251,7 +254,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); - b1.ToGraphDef(&graphdef); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } TF_EXPECT_OK(Encapsulate(&graphdef, &library)); @@ -261,17 +264,17 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float"}, {"output__1:float"}, {}, + "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, }, - {{"output__1", "C:o:0"}}); + {{"c_0_retval", "C:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {}, { - {{"D"}, "BinaryTest", {"input__0", "input__1"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}}, }, - {{"output__2", "D:o:0"}}); + {{"d_0_retval", "D:o:0"}}); { std::unique_ptr<FunctionLibraryDefinition> lib_def( @@ -290,7 +293,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* call2 = b2.opts().FinalizeBuilder(&nb2); Binary(a, call2, b2.opts().WithName("E")); - b2.ToGraphDef(&graphdef_expected); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } // If there are no marked nodes, funcification should be a no-op.
@@ -340,7 +343,8 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr<Graph> graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, &graph, &library)); + /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -371,7 +375,8 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { std::unique_ptr<Graph> graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, &graph, &library)); + /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector<string> expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index f5b99226acd..5cdbebd88ee 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -120,14 +120,12 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, std::unordered_map<string, string> return_values; NodeNameMapping node_names; - for (Node const* node : graph.nodes()) { - if (!node->IsOp()) continue; - + for (Node const* node : graph.op_nodes()) { if (node->type_string() == kArgOp) { int index; DataType type; - GetNodeAttr(node->def(), "T", &type); - GetNodeAttr(node->def(), "index", &index); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().input_arg_size() <= index) { fdef->mutable_signature()->add_input_arg(); } @@ -143,8 +141,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kRetValOp) { int index; DataType type; - GetNodeAttr(node->def(), "T", &type); - GetNodeAttr(node->def(), "index", &index); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().output_arg_size() <= index) { fdef->mutable_signature()->add_output_arg(); } @@ -161,9 +159,11 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } NodeDef* node_def = fdef->add_node_def(); - node_def->CopyFrom(node->def()); + *node_def = node->def(); + if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } node_def->set_name(node_names.Uniquify(node->name())); - node_def->clear_device(); // Reset input names based on graph rather than the NodeDef. node_def->clear_input(); @@ -185,7 +185,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } // Add regular inputs - for (int i = 0; i < in_edges.size(); ++i) { + for (std::vector<const Edge*>::size_type i = 0; i < in_edges.size(); ++i) { const Edge* edge = in_edges[i]; if (edge == nullptr) { return errors::InvalidArgument( @@ -204,8 +204,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, // Populate tensor_renaming.
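One behavioral reversal in `GraphToFunctionDef` worth calling out: the old code stripped devices (`clear_device()`), while the new code keeps a device on each copied `NodeDef`, preferring the placer's `assigned_device_name()` over the user-requested one. In isolation (sketch):

```c++
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"

// Copies a node's definition, letting the placer's assignment win over the
// user-requested device when one exists.
void CopyWithDevice(const tensorflow::Node& node, tensorflow::NodeDef* out) {
  *out = node.def();  // Keeps the requested device, if any.
  if (!node.assigned_device_name().empty()) {
    out->set_device(node.assigned_device_name());
  }
}
```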
NameRangeMap output_ranges; - TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr, - &output_ranges)); + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { for (int i = output.second.first; i < output.second.second; ++i) { const string tensor_name = strings::StrCat( diff --git a/tensorflow/compiler/jit/graph_to_functiondef_test.cc b/tensorflow/compiler/jit/graph_to_functiondef_test.cc index 04b2385c9c9..5c09e96a4c2 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef_test.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/graph/equal_graph_def.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace { @@ -54,7 +54,7 @@ TEST(GraphToFunctionDefTest, Basics) { auto h = ops::_Retval(root.WithOpName("H"), g, 0); GraphDef graph_def; - root.ToGraphDef(&graph_def); + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); GraphConstructorOptions options; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 87d5de09d14..bc68afb322b 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -76,7 +76,7 @@ struct GraphCycles::Rep { GraphCycles::GraphCycles() : rep_(new Rep) {} GraphCycles::~GraphCycles() { - for (int i = 0; i < rep_->nodes_.size(); i++) { + for (Vec::size_type i = 0; i < rep_->nodes_.size(); i++) { delete rep_->nodes_[i]; } delete rep_; } @@ -85,7 +85,7 @@ GraphCycles::~GraphCycles() { bool GraphCycles::CheckInvariants() const { Rep* r = rep_; NodeSet ranks; // Set of ranks seen so far. - for (int32 x = 0; x < r->nodes_.size(); x++) { + for (Vec::size_type x = 0; x < r->nodes_.size(); x++) { Node* nx = r->nodes_[x]; if (nx->visited) { LOG(FATAL) << "Did not clear visited marker on node " << x; @@ -108,7 +108,7 @@ int32 GraphCycles::NewNode() { if (rep_->free_nodes_.empty()) { Node* n = new Node; n->visited = false; - n->data = NULL; + n->data = nullptr; n->rank = rep_->nodes_.size(); rep_->nodes_.push_back(n); return n->rank; @@ -116,7 +116,7 @@ int32 GraphCycles::NewNode() { // Preserve preceding rank since the set of ranks in use must be // a permutation of [0,rep_->nodes_.size()-1]. int32 r = rep_->free_nodes_.back(); - rep_->nodes_[r]->data = NULL; + rep_->nodes_[r]->data = nullptr; rep_->free_nodes_.pop_back(); return r; } @@ -259,7 +259,7 @@ static void Reorder(GraphCycles::Rep* r) { r->deltaf_.end(), r->merged_.begin()); // Assign the ranks in order to the collected list.
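The loop-counter retyping in `graphcycles.cc` above is all one fix: comparing a signed `int`/`int32` against `size()` triggers `-Wsign-compare`, and the signed counter can in principle overflow on very large containers, so the counters now use the container's own `size_type`. The shape of the fix (illustrative):

```c++
#include <vector>

void Visit(const std::vector<int*>& nodes) {
  // size_type matches size() exactly; no signed/unsigned comparison warning.
  for (std::vector<int*>::size_type i = 0; i < nodes.size(); ++i) {
    // ... use nodes[i] ...
  }
}
```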
- for (int32 i = 0; i < r->list_.size(); i++) { + for (Vec::size_type i = 0; i < r->list_.size(); i++) { r->nodes_[r->list_[i]]->rank = r->merged_[i]; } } @@ -277,7 +277,7 @@ static void Sort(const Vec& nodes, Vec* delta) { } static void MoveToList(GraphCycles::Rep* r, Vec* src, Vec* dst) { - for (int32 i = 0; i < src->size(); i++) { + for (Vec::size_type i = 0; i < src->size(); i++) { int32 w = (*src)[i]; (*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank r->nodes_[w]->visited = false; // Prepare for future DFS calls @@ -286,7 +286,7 @@ static void MoveToList(GraphCycles::Rep* r, Vec* src, Vec* dst) { static void ClearVisitedBits(GraphCycles::Rep* r, const Vec& nodes) { - for (int32 i = 0; i < nodes.size(); i++) { + for (Vec::size_type i = 0; i < nodes.size(); i++) { r->nodes_[nodes[i]]->visited = false; } } @@ -332,7 +332,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len, } bool GraphCycles::IsReachable(int32 x, int32 y) const { - return FindPath(x, y, 0, NULL) > 0; + return FindPath(x, y, 0, nullptr) > 0; } bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index f27a616ac9d..e47b782207e 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -230,7 +230,7 @@ TEST(GraphCycles, RandomizedTest) { int new_node = graph_cycles.NewNode(); ASSERT_NE(-1, new_node); VLOG(1) << "adding node " << new_node; - ASSERT_EQ(0, graph_cycles.GetNodeData(new_node)); + ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node)); graph_cycles.SetNodeData( new_node, reinterpret_cast<void*>( static_cast<intptr_t>(new_node + kDataOffset))); @@ -243,7 +243,7 @@ TEST(GraphCycles, RandomizedTest) { break; case 1: // Remove a node - if (nodes.size() > 0) { + if (!nodes.empty()) { int node_index = RandomNode(&rnd, &nodes); int node = nodes[node_index]; nodes[node_index] = nodes.back(); @@ -263,7 +263,7 @@ TEST(GraphCycles, RandomizedTest) { break; case 2: // Add an edge - if (nodes.size() > 0) { + if (!nodes.empty()) { int from = RandomNode(&rnd, &nodes); int to = RandomNode(&rnd, &nodes); if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) { @@ -282,7 +282,7 @@ TEST(GraphCycles, RandomizedTest) { break; case 3: // Remove an edge - if (edges.size() > 0) { + if (!edges.empty()) { int i = RandomEdge(&rnd, &edges); int from = edges[i].from; int to = edges[i].to; @@ -296,7 +296,7 @@ TEST(GraphCycles, RandomizedTest) { break; case 4: // Check a path - if (nodes.size() > 0) { + if (!nodes.empty()) { int from = RandomNode(&rnd, &nodes); int to = RandomNode(&rnd, &nodes); int32 path[2 * kMaxNodes]; @@ -343,7 +343,7 @@ TEST(GraphCycles, RandomizedTest) { ASSERT_NE(-1, new_node); VLOG(1) << "adding node " << new_node; ASSERT_GE(new_node, 0); - ASSERT_EQ(0, graph_cycles.GetNodeData(new_node)); + ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node)); graph_cycles.SetNodeData( new_node, reinterpret_cast<void*>( static_cast<intptr_t>(new_node + kDataOffset))); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD new file mode 100644 index 00000000000..c4116cb8b52 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -0,0 +1,74 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], +) + +cc_library( + name = "xla_local_launch_op", + srcs = ["xla_local_launch_op.cc"], + hdrs = ["xla_local_launch_op.h"], + deps = [ +
"//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:xla_compilation_cache", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:tensorflow_opensource", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_device_launch_op", + srcs = ["xla_device_launch_op.cc"], + hdrs = ["xla_device_launch_op.h"], + deps = [ + "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:xla_compilation_cache", + "//tensorflow/compiler/jit:xla_device", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:variable_ops", + ], +) + +cc_library( + name = "parallel_check_op", + srcs = ["parallel_check_op.cc"], + visibility = ["//tensorflow/compiler/jit:friends"], + deps = [ + "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc new file mode 100644 index 00000000000..c86e03118b5 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/parallel_check_op.cc @@ -0,0 +1,144 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +// Inputs 2*N tensors, outputs the first N inputs. +// Logs errors if input tensor i and i + N are not (near) identical +// in any position. 
+class ParallelCheckOp : public OpKernel { + public: + explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + template <typename T> + int CompareTensors(DataType dtype, const char* v0, const char* v1, + int64 num_elts, int input_idx) { + int failed = 0; + const T* p0 = reinterpret_cast<const T*>(v0); + const T* p1 = reinterpret_cast<const T*>(v1); + double rtol; + legacy_flags::ParallelCheckOpFlags* flags = + legacy_flags::GetParallelCheckOpFlags(); + if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), + &rtol)) { + LOG(ERROR) << "can't convert parallel_check_rtol " + << flags->parallel_check_rtol << " to double"; + } + double atol; + if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), + &atol)) { + LOG(ERROR) << "can't convert parallel_check_atol " + << flags->parallel_check_atol << " to double"; + } + for (int i = 0; i < num_elts; ++i) { + bool ok = (p0[i] == p1[i]); + VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; + if (!ok) { + if (std::is_same<T, float>::value || std::is_same<T, double>::value) { + float tolerance = + std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); + T diff = p0[i] - p1[i]; + if (diff < 0) diff = 0 - diff; + ok = (diff <= tolerance); + } + if (ok) continue; + LOG(ERROR) << "Op " << def().name() << " fails equality at output " + << input_idx << " type " << DataTypeString(dtype) + << " element " << i << ": std_val=" << p0[i] + << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); + if (++failed > 10) break; + } + } + return failed; + } + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "Compute " << def().name(); + const int num_pairs = ctx->num_inputs() / 2; + for (int i = 0; i < num_pairs; ++i) { + CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); + Tensor t0 = ctx->input(i); + Tensor t1 = ctx->input(i + num_pairs); + int64 num_elts = t0.NumElements(); + CHECK_EQ(num_elts, t1.NumElements()); + + // Compare inputs elementwise for near-exact equality. + const char* v0 = t0.tensor_data().data(); + const char* v1 = t1.tensor_data().data(); + int failed = 0; + switch (ctx->input_dtype(i)) { + case DT_INT32: + failed = + CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_INT64: + failed = + CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_FLOAT: + failed = + CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_DOUBLE: + failed = + CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + case DT_BOOL: + failed = + CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i); + break; + default: + LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); + } + if (failed > 0) { + LOG(ERROR) << "check failed for " << def().name() << " output " << i + << " num_elts: " << num_elts; + legacy_flags::ParallelCheckOpFlags* flags = + legacy_flags::GetParallelCheckOpFlags(); + if (flags->parallel_check_failfast) { + LOG(QFATAL) << "failfast on first parallel-check failure"; + } + } else { + VLOG(1) << "check passed for " << def().name() << " output " << i + << " num_elts: " << num_elts; + } + + // Propagate the std value.
+ if (IsRefType(ctx->input_dtype(i))) { + ctx->forward_ref_input_to_ref_output(i, i); + } else { + ctx->set_output(i, ctx->input(i)); + } + } + } + + TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); +}; + +REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), + ParallelCheckOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc new file mode 100644 index 00000000000..29c5ff72429 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +namespace { + +Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** cache) { + XlaDevice::Metadata* metadata; + Status s = rm->Lookup<XlaDevice::Metadata>(rm->default_container(), + "xla_metadata", &metadata); + if (!s.ok()) { + return s; + } + core::ScopedUnref metadata_ref(metadata); + *cache = + new XlaCompilationCache(metadata->client(), metadata->jit_device_type()); + return Status::OK(); +} + +} // namespace + +XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); + function_ = *func; + VLOG(1) << "XlaDeviceLaunch created function=" + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + DataTypeVector constant_types; + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); + num_constant_args_ = constant_types.size(); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +} + +std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables) { + std::vector<OptionalTensor> snapshot(num_variables); + int first_variable = ctx->num_inputs() - num_variables; + for (int i = 0; i < num_variables; ++i) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, first_variable + i); + if (LookupResource(ctx, handle, &variable).ok()) { + mutex_lock lock(*variable->mu()); + snapshot[i].name = handle.name(); + snapshot[i].present = true; + snapshot[i].value = *variable->tensor(); + } + } + return
+ +void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaDeviceLaunch::Compute " + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); + + XlaCompilationCache* cache; + OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>( + rm->default_container(), "xla_compiler", &cache, + [rm](XlaCompilationCache** cache) { + return BuildCompilationCache(rm, cache); + })); + // Holds the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + std::vector<OptionalTensor> variables = + SnapshotResourceVariables(ctx, num_resource_args_); + + XlaCompiler::Options options; + options.client = cache->client(); + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = false; + options.local_executable_has_hybrid_result = false; + + const XlaCompiler::CompilationResult* kernel; + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + variables, ctx, &kernel, nullptr)); + + VLOG(1) << "XLA compilation complete..."; + + OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(), + errors::Internal("Unexpected number of outputs")); + + // Runs the computation, if any. There might not be a computation if all + // outputs were compile-time constants. + std::vector<std::unique_ptr<xla::GlobalData>> outputs; + if (!kernel->computation->IsNull()) { + auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); + + // Builds the inputs to the computation. + std::vector<std::shared_ptr<xla::GlobalData>> arg_handles( + kernel->input_mapping.size()); + std::vector<xla::GlobalData*> arg_ptrs(kernel->input_mapping.size()); + + // Adds the argument tensors. + const int first_variable_arg = ctx->num_inputs() - num_resource_args_; + for (int i = 0; i < kernel->input_mapping.size(); ++i) { + int op_input_num = kernel->input_mapping[i]; + + if (op_input_num >= first_variable_arg) { + arg_handles[i] = XlaTransferManager::GetTensorGlobalData( + variables[op_input_num - first_variable_arg].value); + } else { + arg_handles[i] = + XlaTransferManager::GetTensorGlobalData(ctx->input(op_input_num)); + } + arg_ptrs[i] = arg_handles[i].get(); + } + + // Execute the computation.
+ xla::ExecutionProfile profile; + xla::ExecutionOptions execution_options; + *execution_options.mutable_shape_with_output_layout() = + kernel->xla_output_shape; + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + VLOG(1) << "Executing XLA Computation..."; + auto result = cache->client()->Execute(*kernel->computation, arg_ptrs, + &execution_options, &profile); + auto elapsed = env->NowMicros() - start_time; + OP_REQUIRES(ctx, result.ok(), result.status()); + + VLOG(1) << "Elapsed time: " << elapsed << "us"; + VLOG(1) << "ExecutionProfile: " << profile.DebugString(); + + if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) { + auto outputs_or_error = + cache->client()->DeconstructTuple(*result.ValueOrDie()); + OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status()); + outputs = outputs_or_error.ConsumeValueOrDie(); + } else { + outputs.push_back(result.ConsumeValueOrDie()); + } + } + + XlaDeviceContext* device_context = ctx->op_device_context<XlaDeviceContext>(); + + // Copy XLA outputs to the operator's outputs. + VLOG(2) << "Setting operator output"; + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + Tensor* output; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(i, kernel->outputs[i].shape, &output)); + + if (kernel->outputs[i].is_constant) { + // TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and + // remove the copy from this code. + Status status; + device_context->CopyCPUTensorToDevice( + &kernel->outputs[i].constant_value, nullptr, output, + [&status](const Status& s) { status = s; }); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + } else { + CHECK_LT(output_num, outputs.size()); + XlaTransferManager::SetTensorGlobalData( + std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])), + output); + ++output_num; + } + } + + // Apply variable updates, if any. + VLOG(2) << "Applying variable updates"; + for (int i = 0; i < kernel->variable_updates.size(); ++i) { + const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i]; + OP_REQUIRES(ctx, + write.input_index >= 0 && write.input_index < ctx->num_inputs(), + errors::Internal("Invalid input index for variable write.")); + // This code is very close to being a clone of AssignVariableOp, but the + // key difference is that the contents of an XLA device tensor cannot be + // copied safely; instead we must use + // XlaTransferManager::SetTensorGlobalData. + Var* variable = nullptr; + // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not + // a Tensor. + OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>( + ctx, HandleFromInput(ctx, write.input_index), + &variable, [this, ctx, &write](Var** ptr) { + *ptr = new Var(write.type); + PersistentTensor unused; + Tensor* tmp; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + write.type, write.shape, &unused, &tmp)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + })); + core::ScopedUnref s(variable); + + mutex_lock ml(*variable->mu()); + OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, + errors::Internal("Mismatched type in variable write")); + if (!variable->tensor()->shape().IsSameSize(write.shape)) { + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->allocate_persistent(write.type, write.shape, + &unused, &tmp)); + *variable->tensor() = *tmp; + } + XlaTransferManager::SetTensorGlobalData( + std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])), + variable->tensor()); + ++output_num; + } + + VLOG(1) << "Done"; +} + +XlaDeviceLaunchOp::~XlaDeviceLaunchOp() { + VLOG(1) << "XlaDeviceLaunch destroyed"; +} + +} // namespace tensorflow
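The LookupOrCreate/ScopedUnref dance in Compute is a general ResourceMgr idiom: the caller receives its own counted reference and drops it when Compute returns, while the ResourceMgr keeps the cache alive between invocations. A toy stand-alone sketch of the idiom (plain C++; not the real ResourceMgr API, which also takes containers and mutexes and returns Status):

```c++
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

// Toy reference counting, mirroring tensorflow::core::RefCounted's contract.
class RefCounted {
 public:
  void Ref() { ++refs_; }
  void Unref() { if (--refs_ == 0) delete this; }
  virtual ~RefCounted() = default;
 private:
  int refs_ = 1;  // The creating registry holds the initial reference.
};

// Releases the caller's reference at end of scope, like core::ScopedUnref.
class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() { if (obj_) obj_->Unref(); }
 private:
  RefCounted* obj_;
};

struct ToyCache : RefCounted {};

// Looks up a named resource, creating it on first use; the caller receives
// its own counted reference in addition to the registry's.
ToyCache* LookupOrCreate(std::unordered_map<std::string, ToyCache*>& registry,
                         const std::string& name,
                         const std::function<ToyCache*()>& create) {
  auto it = registry.find(name);
  if (it == registry.end()) it = registry.emplace(name, create()).first;
  it->second->Ref();
  return it->second;
}

int main() {
  std::unordered_map<std::string, ToyCache*> registry;
  ToyCache* cache = LookupOrCreate(registry, "xla_compiler",
                                   [] { return new ToyCache; });
  ScopedUnref cache_ref(cache);  // Held for the duration of "Compute".
  assert(cache == registry["xla_compiler"]);
  return 0;  // cache_ref drops our reference; the registry keeps its own.
}
```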
diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.h b/tensorflow/compiler/jit/kernels/xla_device_launch_op.h new file mode 100644 index 00000000000..65516163c91 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Takes a snapshot of the values of resource variable arguments, which are +// the last `num_variables` arguments. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables); + +// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'XlaCompilationCache' to compile and execute code which is specific +// to the shapes of input Tensors. +class XlaDeviceLaunchOp : public OpKernel { + public: + explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx); + ~XlaDeviceLaunchOp() override; + + void Compute(OpKernelContext* ctx) override; + + private: + NameAttrList function_; + + // Number of compile-time constant arguments.
+ int num_constant_args_; + + // Number of resource variable arguments. + int num_resource_args_; + + Tensor dummy_tensor_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_DEVICE_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc similarity index 76% rename from tensorflow/compiler/jit/xla_local_launch_op.cc rename to tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index acf2ccb8903..40acc0d81d0 100644 --- a/tensorflow/compiler/jit/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/xla_local_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_local_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,20 +38,9 @@ namespace gpu = perftools::gputools; namespace tensorflow { -REGISTER_OP("_XlaLaunch") - .Input("constants: Tconstants") - .Attr("Tconstants: list(type) >= 0") - .Input("args: Targs") - .Attr("Targs: list(type) >= 0") - .Output("results: Tresults") - .Attr("Tresults: list(type) >= 0") - .Attr("function: func") - // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. - .SetIsStateful() - .Doc("XLA Launch Op. For use by the XLA JIT only."); - // Adapter class that wraps a Tensorflow allocator as an XLA allocator. +// Assumes that the Tensorflow allocator permits asynchronous deallocation: +// see comment on `AllowsAsynchronousDeallocation()`. class XlaAllocator : public xla::DeviceMemoryAllocator { public: XlaAllocator(const perftools::gputools::Platform* platform, @@ -66,6 +56,15 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype, const TensorShape& shape, Tensor* tensor) const; + // The Tensorflow BFC allocator used on GPU allows host-side deallocation + // before GPU execution takes place. Tensorflow uses the ordering of the main + // compute stream to enforce a happens-before relationship between a memory + // allocation and code that reuses the same memory. If Tensorflow adds + // support for multiple GPU streams or allocators with different ordering + // requirements, this code may need to change. + // (This attribute has no effect on CPU.) 
+ bool AllowsAsynchronousDeallocation() const override { return true; } + private: OpKernelContext* const op_context_; @@ -143,45 +142,51 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) DataTypeVector constant_types; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); num_constant_args_ = constant_types.size(); + + int num_resource_args; + OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args)); + OP_REQUIRES(ctx, num_resource_args == 0, + errors::Unimplemented( + "XlaLocalLaunchOp does not support resource variables")); + if (device_type_ == DeviceType(DEVICE_CPU)) { + platform_id_ = gpu::host::kHostPlatformId; + } else if (device_type_ == DeviceType(DEVICE_GPU)) { + platform_id_ = gpu::cuda::kCudaPlatformId; + } else { + ctx->SetStatus( + errors::InvalidArgument("Unknown device type for local _XlaLaunch")); + return; + } } -Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { - gpu::Platform::Id platform_id; - if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id = gpu::host::kHostPlatformId; - } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id = gpu::cuda::kCudaPlatformId; - } else { - return errors::InvalidArgument("Unknown device type for local _XlaLaunch"); - } - - auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id); +Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { + auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { return StreamExecutorUtil::ConvertStatus(platform.status()); } - auto client = - xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()); + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); } - const string* compiler_device; - if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device, - /*requires_jit=*/nullptr)) { + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(), + ®istration)) { return errors::InvalidArgument("No JIT device registered for ", device_type_.type()); } - XlaCompiler::Options options; - options.device_type = DeviceType(*compiler_device); - options.client = client.ValueOrDie(); - options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; - *compiler = new XlaCompilationCache(options); + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); return Status::OK(); } void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); @@ -190,25 +195,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; - XlaCompilationCache* compiler; - OP_REQUIRES_OK(ctx, - rm->LookupOrCreate<XlaCompilationCache>( - rm->default_container(), "xla_compiler", &compiler, - [this](XlaCompilationCache** compiler) { - return BuildCompilationCache(compiler); - })); + XlaCompilationCache* cache; + OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>( + rm->default_container(), "xla_cache", &cache, + [this, ctx](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, cache); + })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but // this is more obviously correct.) - core::ScopedUnref compiler_ref(compiler); + core::ScopedUnref cache_ref(cache); - xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client()); + xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client()); + + XlaCompiler::Options options; + options.client = client; + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, - compiler->Compile(function_, num_constant_args_, ctx, &kernel, - &executable)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, {}, + ctx, &kernel, &executable)); VLOG(1) << "Executing XLA Computation..."; @@ -218,7 +229,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { std::unique_ptr<xla::ShapedBuffer> output; bool output_is_tuple; - if (!kernel->computation.IsNull()) { + if (!kernel->computation->IsNull()) { // Build xla::ShapedBuffers that point directly to the Tensor buffers. std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers; arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); @@ -227,8 +238,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Pass remaining parameters. for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int arg_num = kernel->xla_input_shapes[i].first; - const xla::Shape& shape = kernel->xla_input_shapes[i].second; + int arg_num = kernel->input_mapping[i]; + const xla::Shape& shape = kernel->xla_input_shapes[i]; gpu::DeviceMemoryBase dmem( const_cast<char*>(ctx->input(arg_num).tensor_data().data()), ctx->input(arg_num).tensor_data().size()); @@ -316,10 +327,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { } Tensor output_tensor; // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK( - ctx, - xla_allocator.MakeTensorFromBuffer( - buffer, ctx->expected_output_dtype(i), shape, &output_tensor)); + OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer( + buffer, ctx->expected_output_dtype(i), shape, + &output_tensor)); ctx->set_output(i, output_tensor); ++output_num; } diff --git a/tensorflow/compiler/jit/xla_local_launch_op.h b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h similarity index 76% rename from tensorflow/compiler/jit/xla_local_launch_op.h rename to tensorflow/compiler/jit/kernels/xla_local_launch_op.h index 96ae664cbe2..5e4d3336a91 100644 --- a/tensorflow/compiler/jit/xla_local_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/core/framework/allocator.h" @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -31,8 +32,9 @@ namespace tensorflow { // Once all inputs are present, and their shapes are known, the op can // use a 'XlaCompilationCache' to compile and execute code which is specific // to the shapes of input Tensors. -// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes -// arguments into/out of XLA in device memory. +// XlaLocalLaunchOp uses xla::LocalClient::Compile() and +// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device +// memory. class XlaLocalLaunchOp : public OpKernel { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); @@ -42,14 +44,18 @@ class XlaLocalLaunchOp : public OpKernel { private: // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(XlaCompilationCache** compiler); + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** compiler); DeviceType device_type_; NameAttrList function_; int num_constant_args_; + + perftools::gputools::Platform::Id platform_id_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_ +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 725c969c051..f1fef85f994 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,8 +24,9 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -50,22 +51,24 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { } // Make sure we don't recurse infinitely on recursive functions. -const int kMaxRecursionDepth = 5; +const int kMaxRecursionDepth = 10; -bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime); +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime); -// Tests whether 'while_def' is a completely compilable loop. +// Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. 
-bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_def.op(); +bool IsCompilableWhile(const Node& while_node, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { + VLOG(2) << "Loop marking: " << while_node.type_string(); const NameAttrList* name_attr; NodeDef call; Status status; - status = GetNodeAttr(while_def, "cond", &name_attr); + status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'cond' attribute on While node."; return false; @@ -78,7 +81,7 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, VLOG(2) << "Can't compile loop condition: " << cond_func; return false; } - status = GetNodeAttr(while_def, "body", &name_attr); + status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'body' attribute on While node."; return false; @@ -98,8 +101,9 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, // Tests whether 'call_def' is a call to a completely compilable function. // Every operator in the function must be compilable for a function to be // compilable. -bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { VLOG(2) << "Function marking: " << call_def.op(); if (depth > kMaxRecursionDepth) { @@ -109,21 +113,32 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, FunctionLibraryRuntime::Handle handle; Status status = - lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle); + lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; return false; } const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); CHECK(fbody); + const FunctionDef& fdef = fbody->fdef; + bool noinline = false; + if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() && + noinline) { + // The underlying mechanism that calls non-inlined functions uses + // LocalExecutor, which interacts poorly with the LocalExecutor used by + // tf2xla to translate the TF graph into XLA. So we avoid this for now. + // + // TODO(b/36139787): Create a mechanism to set inlining hints. + VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + return false; + } - for (Node* node : fbody->graph->nodes()) { - if (node->IsSource() || node->IsSink()) continue; - if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue; - if (node->def().op() == "While") { + for (Node* node : fbody->graph->op_nodes()) { + if (node->type_string() == "_Arg" || node->type_string() == "_Retval") + continue; + if (node->type_string() == "While") { // Handle functional While loop (not in open source build). - return IsCompilableWhile(node->def(), jit_device_type, depth + 1, - lib_runtime); + return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, @@ -147,6 +162,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { return Status::OK(); } +// Does `node` have a DT_RESOURCE typed argument? 
+bool HasResourceArgument(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end(); +} + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn, @@ -155,28 +176,30 @@ Status FindCompilationCandidates( std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime( nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts)); - for (Node* node : graph.nodes()) { - if (node->IsSource() || node->IsSink()) continue; - + for (Node* node : graph.op_nodes()) { DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; - const string* jit_device_name; - CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name, - /*requires_jit=*/nullptr)); - DeviceType jit_device_type(*jit_device_name); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK( + XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)); + DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) { VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } - if (node->def().op() == "While" && - !IsCompilableWhile(node->def(), jit_device_type, 0, - lib_runtime.get())) { + if (!registration->compile_resource_ops && HasResourceArgument(*node)) { + VLOG(2) << "Compilation rejected node: resource argument " << node->name() + << ": " << node->type_string(); + continue; + } + if (node->type_string() == "While" && + !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) { continue; } candidates->insert(node); @@ -184,85 +207,27 @@ Status FindCompilationCandidates( return Status::OK(); } -// Union-Find data structure used to compute clusters. We use our own -// implementation because we want one key feature: when merging clusters, we -// need to know which value becomes the representative of the merged clusters. -// We use the representatives to name nodes in a cycle detection graph, and we -// need to control which node is named. -// TODO(phawkins): consider merging this code with union-find implementations -// in Tensorflow, e.g., in SimplePlacer. -class Cluster { - public: - Cluster(); - - int Size() { return FindRoot()->size_; } - - // Merges this cluster with 'other'. This cluster's representative becomes - // the representative of the merged cluster; the representative of 'other' - // is ignored. - void Merge(Cluster* other); - - // Each cluster has an associated integer 'representative', initialized to -1 - // by default. - int GetRepresentative() { return FindRoot()->representative_; } - void SetRepresentative(int representative) { - FindRoot()->representative_ = representative; - } - - private: - // Finds the root element of the cluster. Performs path compression. - Cluster* FindRoot(); - - int representative_; - int rank_; - int size_; // Size of the cluster. - Cluster* parent_; +struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph.
+ int representative = -1; }; -Cluster::Cluster() - : representative_(-1), rank_(0), size_(1), parent_(nullptr) {} - -void Cluster::Merge(Cluster* other) { - Cluster* a = FindRoot(); - Cluster* b = other->FindRoot(); - if (a == b) return; - if (a->rank_ > b->rank_) { - b->parent_ = a; - a->size_ += b->size_; - return; - } - - a->parent_ = b; - if (a->rank_ == b->rank_) { - b->rank_++; - } - b->representative_ = a->representative_; - b->size_ += a->size_; -} - -Cluster* Cluster::FindRoot() { - if (!parent_) return this; - // Path compression: update intermediate nodes to point to the root of the - // equivalence class. - parent_ = parent_->FindRoot(); - return parent_; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { Device* device = flr->device(); - const string* jit_device_name; - CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name, - /*requires_jit=*/nullptr)); - DeviceType jit_device_type(*jit_device_name); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), + ®istration)); + DeviceType jit_device_type(registration->compilation_device_name); return IsCompilableCall(ndef, jit_device_type, 0, flr); } Status MarkForCompilationPass::Run( const GraphOptimizationPassOptions& options) { - // TODO(phawkins): precompute the "GetJitDevice" properties each device ahead - // of time. + // TODO(phawkins): precompute the "GetCompilationDevice" properties of each + // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = options.session_options->config.graph_options() .optimizer_options() @@ -283,25 +248,24 @@ Status MarkForCompilationPass::Run( const FunctionLibraryDefinition* fld = options.flib_def; auto is_compilable = [global_jit_level, fld](const Node* node, const DeviceType& device_type) { - const string* jit_device; - bool requires_jit; - if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device, - &requires_jit)) { + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), + ®istration)) { return false; } // If this device requires a JIT, we must say yes. - if (requires_jit) return true; + if (registration->requires_compilation) return true; // If there is a _XlaCompile annotation, use its value. bool compile = false; - Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile); + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); if (status.ok()) return compile; - status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile); + status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; // Otherwise use the value of global_jit_level. - return global_jit_level > 0; + return registration->enable_jit_by_default && global_jit_level > 0; }; return RunImpl(options, is_compilable); } @@ -323,7 +287,7 @@ Status MarkForCompilationPass::RunImpl( VLOG(1) << "MarkForCompilationPass::Run"; // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterJitKernels(); + XlaOpRegistry::RegisterCompilationKernels(); Graph* graph = options.graph->get(); @@ -411,10 +375,11 @@ Status MarkForCompilationPass::RunImpl( // Each compilation candidate belongs to a cluster. The cluster's // representative // names the node in the 'cycles' graph that represents the cluster. 
- std::vector<Cluster> clusters(graph->num_node_ids()); - std::deque<Cluster*> worklist; + std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids()); + std::deque<UnionFind<Cluster>*> worklist; for (Node* node : compilation_candidates) { - clusters[node->id()].SetRepresentative(node->id()); + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); worklist.push_back(&clusters[node->id()]); } @@ -424,15 +389,19 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. while (!worklist.empty()) { - int from = worklist.front()->GetRepresentative(); + int from = worklist.front()->Get().representative; worklist.pop_front(); Node* node_from = graph->FindNodeId(from); if (node_from->IsControlFlow()) { // Control flow nodes aren't compilation candidates and should never // appear. - return errors::Internal("Found control flow node in clustering worklist"); + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); } + string from_scope; + string to_scope; for (int to : cycles.Successors(from)) { if (to >= graph->num_node_ids()) { // Node is a "frame" node that is present only in the cycle detection // graph. continue; } Node* node_to = graph->FindNodeId(to); - if (compilation_candidates.find(node_to) == compilation_candidates.cend()) + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { continue; - if (node_from->assigned_device_name() != node_to->assigned_device_name()) + } + if (node_from->assigned_device_name() != + node_to->assigned_device_name()) { continue; + } + // Look for an _XlaScope on both nodes. If both nodes have a + // scope and the scopes do not match, do not cluster along this + // edge. If even one of the nodes lacks an _XlaScope attribute, + // then it is treated as a "bridge" and a cluster may be created + // along it. We may want to restrict this behavior to require + // all nodes marked with _XlaCompile=true to also have a + // _XlaScope property set (and raise an error otherwise); but + // for now we don't do this. + if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && + from_scope != to_scope) { + continue; + } // Ops that consume shapes cannot be the root of a cluster. This is an // optimization. @@ -476,7 +462,7 @@ Status MarkForCompilationPass::RunImpl( // Count the number of elements in each cluster. std::vector<int> cluster_sizes(graph->num_node_ids()); for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; cluster_sizes[cluster]++; } @@ -490,32 +476,30 @@ Status MarkForCompilationPass::RunImpl( // if compilation is enabled, otherwise there will be no such candidates).
const int min_cluster_size = flags->tf_xla_min_cluster_size; for (Node* n : compilation_candidates) { - int cluster = clusters[n->id()].GetRepresentative(); + int cluster = clusters[n->id()].Get().representative; // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; bool marked_for_compilation = false; - if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) { + if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { marked_for_compilation = compile_attr; - } else if (options.flib_def - ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr) + } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) .ok()) { marked_for_compilation = compile_attr; } // Compile if this operator is placed on a device that requires // compilation. - bool requires_jit = false; DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); - XlaOpRegistry::GetJitDevice(device_type.type(), - /*jit_device_name=*/nullptr, &requires_jit); + const XlaOpRegistry::DeviceRegistration* registration; + XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); // Or compile if this is a cluster of >= min_cluster_size compilable // operators. if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || - requires_jit) { + registration->requires_compilation) { string& name = cluster_names[cluster]; if (name.empty()) { name = strings::StrCat("cluster_", cluster_sequence_num++); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 61b2031a36e..9f30e12e0e3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -56,7 +57,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) { std::unordered_map<string, string> ids; for (Node* node : graph.nodes()) { string cluster; - if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) { + if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { CHECK(!cluster.empty()); ids[node->name()] = cluster; } @@ -77,7 +78,7 @@ TEST(XlaCompilationTest, Chains) { ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); ops::UnaryOp("Relu", e, builder.opts().WithName("F")); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -102,7 +103,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { Node* b = ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -122,7 +123,7 @@ TEST(XlaCompilationTest, CompilableCycles) { .WithAttr("value", Tensor())); Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -145,7 +146,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) { .WithAttr("value", Tensor(DT_COMPLEX64, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -174,7 +175,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { concat_builder.Input(dim).Input({a, a}).Attr("N", 2); builder.opts().FinalizeBuilder(&concat_builder); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -183,13 +184,20 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { } TEST(XlaCompilationTest, FunctionCalls) { - FunctionDefLibrary flib; - *flib.add_function() = FunctionDefHelper::Define( + FunctionDef compilable = FunctionDefHelper::Define( "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}}); - *flib.add_function() = + FunctionDef uncompilable = FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"}, {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); + FunctionDef noinline = compilable; + noinline.mutable_signature()->set_name("NoInlineFn"); + AddAttr("_noinline", bool(true), noinline.mutable_attr()); + + FunctionDefLibrary flib; + *flib.add_function() = compilable; + *flib.add_function() = uncompilable; + *flib.add_function() = noinline; FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); std::unique_ptr<Graph> graph(new Graph(&flib_def)); @@ -201,7 +209,8 @@ TEST(XlaCompilationTest, FunctionCalls) { Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B")); Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); - builder.ToGraph(graph.get()); + ops::BinaryOp("NoInlineFn", c, c,
builder.opts().WithName("E")); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph, &flib_def); @@ -212,6 +221,7 @@ TEST(XlaCompilationTest, FunctionCalls) { EXPECT_EQ(clusters["B"], clusters["C"]); EXPECT_TRUE(clusters.find("A") == clusters.cend()); EXPECT_TRUE(clusters.find("D") == clusters.cend()); + EXPECT_TRUE(clusters.find("E") == clusters.cend()); } // Metadata-only operators such as Shape/Rank/Size may not be the root of a @@ -231,8 +241,8 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B")); Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); - ops::UnaryOp("Shape", d, builder.opts().WithName("C")); - builder.ToGraph(graph.get()); + ops::UnaryOp("Shape", d, builder.opts().WithName("E")); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); auto clusters = GetClusters(*graph); @@ -318,7 +328,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { d_builder.Input({c, c}); builder.opts().FinalizeBuilder(&d_builder); - builder.ToGraph(graph.get()); + TF_EXPECT_OK(builder.ToGraph(graph.get())); } MarkForCompilation(&graph); @@ -344,7 +354,7 @@ TEST(XlaCompilationTest, Loops) { auto d = ops::Add(root.WithOpName("D"), c, exit); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); - root.ToGraph(graph.get()); + TF_EXPECT_OK(root.ToGraph(graph.get())); MarkForCompilation(&graph); auto clusters = GetClusters(*graph); @@ -354,5 +364,96 @@ TEST(XlaCompilationTest, Loops) { EXPECT_EQ(0, clusters.size()); } +TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor()) + .WithAttr(kXlaScopeAttr, "ScopeA")); + Node* b = ops::UnaryOp( + "Relu", a, + builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); + ops::BinaryOp( + "MatMul", a, b, + builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); + TF_CHECK_OK(builder.ToGraph(graph.get())); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + // The computation is: C = A @ relu(A) + // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. + // In this case, we cannot fuse anything, and there are no clusters. + EXPECT_EQ(0, clusters.size()); +} + +TEST(XlaCompilationTest, CyclesWithSplittingScopes) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor()) + .WithAttr(kXlaScopeAttr, "Scope1")); + Node* b = ops::UnaryOp( + "Relu", a, + builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1")); + Node* c = ops::BinaryOp( + "MatMul", a, b, + builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2")); + ops::BinaryOp( + "Add", b, c, + builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); + TF_CHECK_OK(builder.ToGraph(graph.get())); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + // The computation is: D = relu(A) + (A @ relu(A)) + // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
+ // In this case, we can fuse the A and relu(A), and we can fuse the + // second half of the operations; there are two clusters. + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_NE(clusters["A"], clusters["C"]); + EXPECT_EQ(clusters["C"], clusters["D"]); +} + +TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor()) + .WithAttr(kXlaScopeAttr, "ScopeA")); + Node* b = ops::UnaryOp( + "Relu", a, + builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_CHECK_OK(builder.ToGraph(graph.get())); + } + + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + + // The computation is: C = A @ relu(A) + // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope, so it + // acts as a "bridge". In this case, B and C can be fused, but A cannot join + // them. + EXPECT_EQ(2, clusters.size()); + EXPECT_NE(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["B"], clusters["C"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD new file mode 100644 index 00000000000..8d1fa03cc0d --- /dev/null +++ b/tensorflow/compiler/jit/ops/BUILD @@ -0,0 +1,45 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], +) + +cc_library( + name = "xla_ops", + srcs = [ + "xla_ops.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +cc_library( + name = "parallel_check_op", + srcs = ["parallel_check_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/jit/ops/parallel_check_op.cc b/tensorflow/compiler/jit/ops/parallel_check_op.cc new file mode 100644 index 00000000000..db5c1955788 --- /dev/null +++ b/tensorflow/compiler/jit/ops/parallel_check_op.cc @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("ParallelCheck") + .Attr("T: list(type) >= 0") + .Input("expected: T") + .Input("actual: T") + .Output("result: T") + .Doc(R"doc( +Op that compares two sets of inputs for near-identity, and propagates the first. +Inequality is logged to ERROR log.
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc new file mode 100644 index 00000000000..07320b43dab --- /dev/null +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("_XlaLaunch") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Input("resources: Nresources * resource") + .Attr("Nresources: int >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Attr("function: func") + // XLA random-number generation ops are stateful. + // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + .SetIsStateful() + .Doc("XLA Launch Op. For use by the XLA JIT only."); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/parallel_check_op.cc b/tensorflow/compiler/jit/parallel_check_op.cc deleted file mode 100644 index d07da46ca04..00000000000 --- a/tensorflow/compiler/jit/parallel_check_op.cc +++ /dev/null @@ -1,154 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace { - -REGISTER_OP("ParallelCheck") - .Attr("T: list(type) >= 0") - .Input("expected: T") - .Input("actual: T") - .Output("result: T") - .Doc(R"doc( -Op that compares two sets of inputs for near-identity, and propagates the first. -Inequality is logged to ERROR log. -)doc"); - -// Inputs 2*N tensors, outputs the first N inputs. -// Logs errors if input tensor i and i + N are not (near) identical -// in any position. 
-class ParallelCheckOp : public OpKernel { - public: - explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - template <typename T> - int CompareTensors(DataType dtype, const char* v0, const char* v1, - int64 num_elts, int input_idx) { - int failed = 0; - const T* p0 = reinterpret_cast<const T*>(v0); - const T* p1 = reinterpret_cast<const T*>(v1); - double rtol; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), - &rtol)) { - LOG(ERROR) << "can't convert parallel_check_rtol " - << flags->parallel_check_rtol << " to double"; - } - double atol; - if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), - &atol)) { - LOG(ERROR) << "can't convert parallel_check_atol " - << flags->parallel_check_atol << " to double"; - } - for (int i = 0; i < num_elts; ++i) { - bool ok = (p0[i] == p1[i]); - VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; - if (!ok) { - if (std::is_same<T, float>::value || std::is_same<T, double>::value) { - float tolerance = - std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); - T diff = p0[i] - p1[i]; - if (diff < 0) diff = 0 - diff; - ok = (diff <= tolerance); - } - if (ok) continue; - LOG(ERROR) << "Op " << def().name() << " fails equality at output " - << input_idx << " type " << DataTypeString(dtype) - << " element " << i << ": std_val=" << p0[i] - << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); - if (++failed > 10) break; - } - } - return failed; - } - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << def().name(); - const int num_pairs = ctx->num_inputs() / 2; - for (int i = 0; i < num_pairs; ++i) { - CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); - Tensor t0 = ctx->input(i); - Tensor t1 = ctx->input(i + num_pairs); - int64 num_elts = t0.NumElements(); - CHECK_EQ(num_elts, t1.NumElements()); - - // Compare inputs elementwise for near-exact equality. - const char* v0 = t0.tensor_data().data(); - const char* v1 = t1.tensor_data().data(); - int failed = 0; - switch (ctx->input_dtype(i)) { - case DT_INT32: - failed = - CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_INT64: - failed = - CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_FLOAT: - failed = - CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_DOUBLE: - failed = - CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_BOOL: - failed = - CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - default: - LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); - } - if (failed > 0) { - LOG(ERROR) << "check failed for " << def().name() << " output " << i - << " num_elts: " << num_elts; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (flags->parallel_check_failfast) { - LOG(QFATAL) << "failfast on first parallel-check failure"; - } - } else { - VLOG(1) << "check passed for " << def().name() << " output " << i - << " num_elts: " << num_elts; - } - - // Propagate the std value.
- if (IsRefType(ctx->input_dtype(i))) { - ctx->forward_ref_input_to_ref_output(i, i); - } else { - ctx->set_output(i, ctx->input(i)); - } - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); -}; - -REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), - ParallelCheckOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/jit/union_find.h new file mode 100644 index 00000000000..a1a7a6a4d0d --- /dev/null +++ b/tensorflow/compiler/jit/union_find.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_JIT_UNION_FIND_H_ + +namespace tensorflow { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must +// be copyable. +template <typename T> +class UnionFind { + public: + UnionFind() : rank_(0), size_(1), parent_(nullptr) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& Get() { return FindRoot()->value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int rank_; + int size_; // Size of the cluster. + UnionFind* parent_; + T value_; +}; + +template <typename T> +void UnionFind<T>::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return; + if (a->rank_ > b->rank_) { + b->parent_ = a; + a->size_ += b->size_; + return; + } + + a->parent_ = b; + if (a->rank_ == b->rank_) { + b->rank_++; + } + b->value_ = a->value_; + b->size_ += a->size_; +} + +template <typename T> +UnionFind<T>* UnionFind<T>::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_UNION_FIND_H_
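For readers new to this structure, a small usage sketch follows; it mirrors how mark_for_compilation_pass.cc wraps each node's Cluster in a UnionFind cell so that merging preserves the first argument's value (the Example function below is illustration only, not part of the change):

```c++
#include <cassert>
#include "tensorflow/compiler/jit/union_find.h"

// Mirrors the Cluster struct added to mark_for_compilation_pass.cc.
struct Cluster {
  int representative = -1;
};

void Example() {
  tensorflow::UnionFind<Cluster> a, b;
  a.Get().representative = 1;
  b.Get().representative = 2;
  a.Merge(&b);
  // Both cells now resolve to one cluster of size 2 whose value came from
  // 'a'; this is how the pass controls which node id names the cluster.
  assert(a.Size() == 2 && b.Size() == 2);
  assert(a.Get().representative == 1);
  assert(b.Get().representative == 1);
}
```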
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -35,9 +37,9 @@ limitations under the License. namespace tensorflow { -XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options) - : compiler_(options) {} - +XlaCompilationCache::XlaCompilationCache(xla::Client* client, + DeviceType device_type) + : client_(client), device_type_(std::move(device_type)) {} XlaCompilationCache::~XlaCompilationCache() = default; string XlaCompilationCache::DebugString() { @@ -54,7 +56,7 @@ string XlaCompilationCache::SignatureDebugString(const Signature& sig) { } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.first, ":", v.second.DebugString()); + strings::StrAppend(&result, "; ", v.DebugString()); } return result; } @@ -65,9 +67,7 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { - if (arg_values[i].first != other.arg_values[i].first || - arg_values[i].second.tensor_data() != - other.arg_values[i].second.tensor_data()) { + if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { return false; } } @@ -85,68 +85,159 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( } } for (const auto& arg : signature.arg_values) { - h = Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, Hash64(arg.second.tensor_data().data(), - arg.second.tensor_data().size())); + h = Hash64Combine( + h, Hash64(arg.tensor_data().data(), arg.tensor_data().size())); } return h; } +Status XlaCompilationCache::BuildSignature( + const NameAttrList& function, int num_constant_args, + const std::vector& variable_args, OpKernelContext* ctx, + Signature* signature) { + signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); + signature->arg_values.resize(num_constant_args); + + signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); + + // Inputs are in the order: constants, non-constants, resource variables. + int input_num = 0; + // Use the values of compile time constants in the signature-> + while (input_num < num_constant_args) { + signature->arg_values[input_num] = ctx->input(input_num); + ++input_num; + } + // Add the types and shapes of the remaining arguments. + while (input_num < ctx->num_inputs() - variable_args.size()) { + signature->arg_types.emplace_back(ctx->input_dtype(input_num), + ctx->input(input_num).shape()); + ++input_num; + } + // For variable signatures, use the type and shape of the variable's + // current value. + for (const OptionalTensor& variable : variable_args) { + TF_RET_CHECK(input_num < ctx->num_inputs()); + if (variable.present) { + signature->arg_types.emplace_back(variable.value.dtype(), + variable.value.shape()); + } else { + signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + } + ++input_num; + } + return Status::OK(); +} + namespace { // Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch // op. The first `num_constant_args` arguments must be host-memory Tensors. 
-std::vector BuildArguments(int num_constant_args, - OpKernelContext* ctx) { - std::vector args(ctx->num_inputs()); - int parameter_num = 0; - for (int i = 0; i < ctx->num_inputs(); ++i) { - args[i].type = ctx->input(i).dtype(); - args[i].shape = ctx->input(i).shape(); - if (i < num_constant_args || ctx->input(i).NumElements() == 0) { - args[i].parameter = -1; - args[i].constant_value = ctx->input(i); - } else { - args[i].parameter = parameter_num; - ++parameter_num; - } +Status BuildArguments(int num_constant_args, + const std::vector& variable_args, + OpKernelContext* ctx, + std::vector* args) { + args->resize(ctx->num_inputs()); + + int input_num = 0; + + // Handles compile-time constants. + TF_RET_CHECK(num_constant_args <= ctx->num_inputs()); + while (input_num < num_constant_args) { + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + XlaCompiler::Argument& arg = (*args)[input_num]; + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); + arg.constant_value = input; + ++input_num; } - return args; + + // Handles the non-constant arguments. + int num_variable_args = variable_args.size(); + int num_nonconst_args = + ctx->num_inputs() - num_variable_args - num_constant_args; + TF_RET_CHECK(num_nonconst_args >= 0); + while (input_num < num_constant_args + num_nonconst_args) { + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + XlaCompiler::Argument& arg = (*args)[input_num]; + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); + ++input_num; + } + + // Handles resource variables. + TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs()); + for (int variable_id = 0; variable_id < num_variable_args; ++variable_id) { + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + + XlaCompiler::Argument& arg = (*args)[input_num]; + + arg.name = variable_args[variable_id].name; + if (variable_args[variable_id].present) { + const Tensor& value = variable_args[variable_id].value; + arg.kind = XlaCompiler::Argument::kVariable; + arg.type = value.dtype(); + arg.shape = value.shape(); + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do permit + // uninitialized variables. 
+ arg.kind = XlaCompiler::Argument::kUninitializedVariable; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } + ++input_num; + } + + return Status::OK(); } } // namespace Status XlaCompilationCache::Compile( - const NameAttrList& function, int num_constant_args, OpKernelContext* ctx, + const XlaCompiler::Options& options, const NameAttrList& function, + int num_constant_args, const std::vector& variable_args, + OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { - std::vector argshapes; - VLOG(2) << "num_inputs = " << ctx->num_inputs() - << " num_constant_args= " << num_constant_args; + VLOG(2) << "num_inputs=" << ctx->num_inputs() + << " num_constant_args=" << num_constant_args + << " num_variable_args=" << variable_args.size(); for (int i = 0; i < ctx->num_inputs(); i++) { TensorShape shape = ctx->input(i).shape(); - VLOG(2) << i << ": dtype=" << ctx->input_dtype(i) + VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i)) << " present=" << ctx->has_input(i) << " shape=" << shape.DebugString(); - argshapes.push_back(shape.DebugString()); + } + for (const OptionalTensor& variable : variable_args) { + VLOG(2) << "variable present=" << variable.present + << " type=" << DataTypeString(variable.value.dtype()) + << " shape=" << variable.value.shape().DebugString(); } VLOG(2) << "num_outputs = " << ctx->num_outputs(); for (int i = 0; i < ctx->num_outputs(); i++) { VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); } } + + TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs()); + Signature signature; - signature.name = Canonicalize(function.name(), function.attr()); - for (int i = 0; i < ctx->num_inputs(); ++i) { - signature.arg_types.emplace_back(ctx->input_dtype(i), - ctx->input(i).shape()); - if (i < num_constant_args) { - signature.arg_values.emplace_back(i, ctx->input(i)); - } - } + TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args, + ctx, &signature)); VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not @@ -169,24 +260,22 @@ Status XlaCompilationCache::Compile( if (!entry->compiled) { // Do the actual JIT compilation without holding the lock (it can take // a long time.) 
- std::vector args = - BuildArguments(num_constant_args, ctx); - - std::unique_ptr flr(NewFunctionLibraryRuntime( - compiler_.device_mgr(), ctx->env(), compiler_.device(), - TF_GRAPH_DEF_VERSION, - ctx->function_library()->GetFunctionLibraryDefinition(), - OptimizerOptions(), nullptr /* custom_kernel_creator */)); + std::vector args; + TF_RETURN_IF_ERROR( + BuildArguments(num_constant_args, variable_args, ctx, &args)); + XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler_.CompileFunction( - flr.get(), function, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, + &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { if (entry->executable == nullptr && - !entry->compilation_result.computation.IsNull()) { - entry->compilation_status = compiler_.BuildExecutable( + !entry->compilation_result.computation->IsNull()) { + XlaCompiler compiler(options); + entry->compilation_status = compiler.BuildExecutable( entry->compilation_result, &entry->executable); } *executable = entry->executable.get(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 44d76db0fd4..4ffcb68a322 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -30,6 +29,13 @@ limitations under the License. namespace tensorflow { +// Struct that represents a possibly-absent Tensor. +struct OptionalTensor { + string name; // A descriptive name + bool present = false; // Is the tensor present? + Tensor value; // If present, what is the Tensor's value? +}; + // The XlaCompilationCache class caches the results of the XlaCompiler class, // which converts a Tensorflow graph into a compiled XLA compilation. // @@ -40,39 +46,47 @@ namespace tensorflow { // bound. class XlaCompilationCache : public ResourceBase { public: - explicit XlaCompilationCache(const XlaCompiler::Options& options); + XlaCompilationCache(xla::Client* client, DeviceType device_type); ~XlaCompilationCache() override; // Compiles a function into a XlaCompiler::CompilationResult that can be used - // to execute an XLA Computation. `compilation_result` must be non-null. - // If `executable` is non-null, also builds an xla::LocalExecutable and sets - // `executable to point to it. The resulting executable pointer may be null if - // the computation has no non-constant outputs. - // Compilation results are cached. - Status Compile(const NameAttrList& function, int num_constant_args, + // to execute an XLA Computation. Compilation results are cached. + // `function` is the name of a Tensorflow function to compile. + // `num_constant_args` is the number of compile-time constant arguments to + // `function`. `variable_args` is a snapshot of the current values of the + // resource variable arguments to `function`; uninitialized variables are + // represented by an absent OptionalTensor. 
+ // The result of compilation is written to `*compilation_result`, which must + // be non-null. If `executable` is non-null, also builds an + // xla::LocalExecutable and sets `executable to point to it. The resulting + // executable pointer may be null if the computation has no non-constant + // outputs. + Status Compile(const XlaCompiler::Options& options, + const NameAttrList& function, int num_constant_args, + const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable); - xla::Client* client() const { return compiler_.client(); } + xla::Client* client() const { return client_; } + const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: - XlaCompiler compiler_; - std::unique_ptr function_library_runtime_; + xla::Client* const client_; + const DeviceType device_type_; // Describes the types, shapes and any compile-time constant arguments - // to a kernel. + // to a kernel. Key that uniquely identifies a compilation output. struct Signature { string name; std::vector> arg_types; - // List of (argument #, value) pairs for arguments whose values are - // part of the JIT signature, and that are therefore constants in any given - // JIT compilation. Tensors must be in host memory. - std::vector> arg_values; + // List of Tensor values for compile-time constant arguments to the + // compilation, ordered by argument number. Tensors must be in host memory. + std::vector arg_values; bool operator==(const Signature& other) const; @@ -82,6 +96,11 @@ class XlaCompilationCache : public ResourceBase { }; static string SignatureDebugString(const Signature& sig); + // Builds the signature for a compilation. + Status BuildSignature(const NameAttrList& function, int num_constant_args, + const std::vector& variable_args, + OpKernelContext* ctx, Signature* signature); + // The value associated with a cache entry. struct Entry { mutex mu; diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 92784a5358b..e8b1f542ecf 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,16 +16,15 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. +#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -const char* const DEVICE_XLA_CPU = "XLA_CPU"; - class XlaCpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 7835146a01d..5e336c5287b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -19,11 +19,10 @@ limitations under the License. 
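Aside: a rough sketch of how a caller might snapshot resource variables into `OptionalTensor` and invoke the new `Compile` entry point declared above. The surrounding kernel, the `variables` collection, and the exact `XlaCompiler::Options` fields are assumptions; this is not the actual _XlaLaunch implementation.

```c++
// Hypothetical caller sketch. Each resource-variable input is snapshotted so
// that signature building and argument building see one consistent value.
std::vector<OptionalTensor> variable_args;
for (Var* var : variables) {  // 'variables' assumed gathered from DT_RESOURCE inputs
  OptionalTensor snapshot;
  snapshot.name = var->name();
  if (var->is_initialized) {
    snapshot.present = true;
    snapshot.value = *var->tensor();  // copy of the current value
  }
  variable_args.push_back(snapshot);  // absent => uninitialized variable
}

XlaCompiler::Options options;  // exact fields are assumptions
options.client = cache->client();

const XlaCompiler::CompilationResult* result;
xla::LocalExecutable* executable;
TF_RETURN_IF_ERROR(cache->Compile(options, function, num_constant_args,
                                  variable_args, ctx, &result, &executable));
```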
#include #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -41,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -56,10 +56,13 @@ namespace tensorflow { << device_ordinal; // These are no-ops if they have already been done previously for - // this device_name/jit_device_name pair. - XlaOpRegistry::RegisterJitKernels(); - XlaOpRegistry::RegisterJitDevice(device_name, jit_device_name, - /*requires_jit=*/true); + // this device_name/compilation_device_name pair. + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = jit_device_name; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); auto platform = perftools::gputools::MultiPlatformManager::PlatformWithName( platform_name); @@ -106,12 +109,23 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; } +/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, + Metadata** metadata) { + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) { + return errors::Internal("No resource manager."); + } + TF_RETURN_IF_ERROR( + rm->Lookup(rm->default_container(), "xla_metadata", metadata)); + return Status::OK(); +} + XlaDevice::XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), @@ -161,6 +175,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); + // When TraceMe profiling is off (which is the default), the + // following TraceMe constructor is simply a conditional test of + // false value. Measurements show that its overhead is negligible. 
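Aside: a sketch of how an op kernel might consume the `XlaDevice::GetMetadata` helper added above. The kernel name is hypothetical; the `client()` and `jit_device_type()` accessors and the ref-counting convention follow the declarations visible in this diff.

```c++
// Hypothetical kernel body. GetMetadata looks the Metadata up in the device's
// ResourceMgr, which hands back an owned reference; ScopedUnref releases it.
void MyXlaAwareOp::Compute(OpKernelContext* ctx) {
  XlaDevice::Metadata* metadata;
  OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  core::ScopedUnref metadata_ref(metadata);

  xla::Client* client = metadata->client();
  const DeviceType& jit_type = metadata->jit_device_type();
  // ... e.g., build or look up an XlaCompilationCache(client, jit_type) ...
}
```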
+ port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); op_kernel->Compute(context); } @@ -168,6 +186,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); op_kernel->ComputeAsync(context, done); } @@ -203,6 +222,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { + XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { return new XlaDeviceDummyOp(context); diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3de14f30616..0badb390c6b 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -67,6 +67,10 @@ class XlaDevice : public LocalDevice { perftools::gputools::Platform* platform_; // Not owned. }; + // Sets `*metadata` to the XlaDevice Metadata in the resource manager of + // `ctx`. + static Status GetMetadata(OpKernelContext* ctx, Metadata** metadata); + // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. diff --git a/tensorflow/compiler/jit/xla_device_launch_op.cc b/tensorflow/compiler/jit/xla_device_launch_op.cc deleted file mode 100644 index 1d5d7da14cc..00000000000 --- a/tensorflow/compiler/jit/xla_device_launch_op.cc +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/xla_device_launch_op.h" - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/compiler/jit/xla_device_context.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/env.h" - -namespace tensorflow { - -namespace { - -Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { - XlaDevice::Metadata* metadata; - Status s = rm->Lookup(rm->default_container(), - "xla_metadata", &metadata); - if (!s.ok()) { - return s; - } - core::ScopedUnref metadata_ref(metadata); - XlaCompiler::Options options; - options.device_type = metadata->jit_device_type(); - options.client = metadata->client(); - options.allow_cpu_custom_calls = false; - options.local_executable_has_hybrid_result = false; - *compiler = new XlaCompilationCache(options); - return Status::OK(); -} - -} // namespace - -XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - VLOG(1) << "XlaDeviceLaunch created function=" - << Canonicalize(function_.name(), function_.attr()); - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); -} - -void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaDeviceLaunch::Compute " - << Canonicalize(function_.name(), function_.attr()); - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - - XlaCompilationCache* compiler; - OP_REQUIRES_OK(ctx, - rm->LookupOrCreate( - rm->default_container(), "xla_compiler", &compiler, - [rm](XlaCompilationCache** compiler) { - return BuildCompilationCache(rm, compiler); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref compiler_ref(compiler); - - const XlaCompiler::CompilationResult* kernel; - OP_REQUIRES_OK( - ctx, - compiler->Compile(function_, num_constant_args_, ctx, &kernel, nullptr)); - - VLOG(1) << "Executing XLA Computation..."; - - OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(), - errors::Internal("Unexpected number of outputs")); - - // Run the computation, if any. There might not be a computation if all - // outputs were compile-time constants. - std::vector> outputs; - if (!kernel->computation.IsNull()) { - auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); - - // Convert argument tensors to xla::GlobalData pointers. 
- std::vector> arg_handles( - kernel->xla_input_shapes.size()); - std::vector arg_ptrs(kernel->xla_input_shapes.size()); - for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int input_num = kernel->xla_input_shapes[i].first; - arg_handles[i] = - XlaTransferManager::GetTensorGlobalData(ctx->input(input_num)); - arg_ptrs[i] = arg_handles[i].get(); - } - - // Execute the computation. - xla::ExecutionProfile profile; - xla::ExecutionOptions execution_options; - *execution_options.mutable_shape_with_output_layout() = - kernel->xla_output_shape; - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - auto result = compiler->client()->Execute(kernel->computation, arg_ptrs, - &execution_options, &profile); - auto elapsed = env->NowMicros() - start_time; - OP_REQUIRES(ctx, result.ok(), result.status()); - - VLOG(1) << "Elapsed time: " << elapsed << "us"; - VLOG(1) << "ExecutionProfile: " << profile.DebugString(); - - if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) { - auto outputs_or_error = - compiler->client()->DeconstructTuple(*result.ValueOrDie()); - OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status()); - outputs = outputs_or_error.ConsumeValueOrDie(); - } else { - outputs.push_back(result.ConsumeValueOrDie()); - } - } - - XlaDeviceContext* device_context = ctx->op_device_context(); - - // Copy XLA outputs to the operator's outputs. - int output_num = 0; - for (int i = 0; i < ctx->num_outputs(); ++i) { - Tensor* output; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(i, kernel->outputs[i].shape, &output)); - if (kernel->outputs[i].is_constant) { - // TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and - // remove the copy from this code. - Status status; - device_context->CopyCPUTensorToDevice( - &kernel->outputs[i].constant_value, nullptr, output, - [&status](const Status& s) { status = s; }); - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - } else { - CHECK_LT(output_num, outputs.size()); - XlaTransferManager::SetTensorGlobalData( - std::shared_ptr(std::move(outputs[output_num])), - output); - ++output_num; - } - } - - VLOG(1) << "Done"; -} - -XlaDeviceLaunchOp::~XlaDeviceLaunchOp() { - VLOG(1) << "XlaDeviceLaunch destroyed"; -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_launch_op.h b/tensorflow/compiler/jit/xla_device_launch_op.h deleted file mode 100644 index fbb9319b844..00000000000 --- a/tensorflow/compiler/jit/xla_device_launch_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ - -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph -// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is -// responsible for handling interactions with the TensorFlow executor. -// Once all inputs are present, and their shapes are known, the op can -// use a 'TlaJit' to compile and execute code which is specific -// to the shapes of input Tensors. -class XlaDeviceLaunchOp : public OpKernel { - public: - explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx); - ~XlaDeviceLaunchOp() override; - - void Compute(OpKernelContext* ctx) override; - - private: - NameAttrList function_; - int num_constant_args_; - Tensor dummy_tensor_; - - TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 0d3a2fa3393..f68dba6b6a2 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -19,13 +19,6 @@ limitations under the License. namespace tensorflow { -void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs, - const Tensor& rhs) { - std::shared_ptr gd = - XlaTransferManager::GetTensorGlobalData(rhs); - XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs); -} - XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 1fcb515ddb3..8699006ebc5 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -18,9 +18,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ -#include "tensorflow/compiler/jit/xla_device_launch_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/assign_op.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/identity_op.h" @@ -30,14 +29,6 @@ limitations under the License. namespace tensorflow { -// Implementation of Assign for XLA devices. -class XlaDeviceAssignOp : public AssignOp { - public: - using AssignOp::AssignOp; - - void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override; -}; - // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be // rewritten to a _XlaLaunch op. 
If it is called, it means the placer placed an @@ -49,8 +40,11 @@ class XlaDeviceDummyOp : public OpKernel { }; #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER( \ - Name("_XlaLaunch").Device(DEVICE).HostMemory("constants"), KERNEL); + REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("resources"), \ + KERNEL); #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ @@ -65,53 +59,13 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ - REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), \ - XlaDeviceDummyOp); \ + REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ + REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ + PlaceholderOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - VariableOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - VariableOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \ - TemporaryVariableOp); \ - REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \ - .Device(DEVICE) \ - .TypeConstraint("T", TYPES), \ - DestroyTemporaryVariableOp); \ - REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \ - .Device(DEVICE) \ - .TypeConstraint("dtype", TYPES) \ - .HostMemory("is_initialized"), \ - IsVariableInitializedOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \ - XlaDeviceAssignOp); \ - \ - REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ - ControlTriggerOp); \ - REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ - REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ - REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ - NextIterationOp); \ - REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ - SwitchOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ - REGISTER_KERNEL_BUILDER(Name("LoopCond") \ - .Device(DEVICE) \ - .HostMemory("input") \ - .HostMemory("output"), \ - IdentityOp); - -// TODO(phawkins): do we really need Placeholder? Should it be a real -// implementation of Placeholder? - -// TODO(b/32507444): the registrations for the control flow operators are -// temporary and exist primarily to work around a bug in the graph partitioning -// code. + Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ + ResourceHandleOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index db4c86505cb..872588a24e0 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,16 +16,15 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. 
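Aside: for readers unfamiliar with the registration macros above, one expansion of `REGISTER_XLA_LAUNCH_KERNEL` looks roughly like the following; the device string is a placeholder, and the macro itself (as changed in this diff) pins both the compile-time-constant inputs and the resource-handle inputs to host memory.

```c++
// Illustrative expansion for a hypothetical device string "XLA_MY_DEVICE".
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
                            .Device("XLA_MY_DEVICE")
                            .HostMemory("constants")   // compile-time constants
                            .HostMemory("resources"),  // DT_RESOURCE handles
                        XlaDeviceLaunchOp);
```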
+#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -const char* const DEVICE_XLA_GPU = "XLA_GPU"; - class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD new file mode 100644 index 00000000000..8c2e9a7c818 --- /dev/null +++ b/tensorflow/compiler/plugin/BUILD @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Configuration file for an XLA plugin. +- please don't check in changes to this file +- to prevent changes appearing in git status, use: + git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD + +To add additional devices to the XLA subsystem, add targets to the +dependency list in the 'plugin' target. For instance: + + deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], +""" + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "plugin", + deps = [ + "//tensorflow/compiler/plugin/executor:plugin_lib", + ], +) diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD new file mode 100644 index 00000000000..9bc706abdf6 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/BUILD @@ -0,0 +1,32 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "plugin_lib", + srcs = glob([ + "*.cc", + ]), + hdrs = glob([ + "*.h", + ]), + deps = [ + "//tensorflow/compiler/jit:xla_jit_headers_lib", + "//tensorflow/compiler/xla:xla_headers_lib", + "//tensorflow/compiler/xla/service:hlo_evaluator", + "//third_party/eigen3", + "@local_config_cuda//cuda:cuda_headers", + "@protobuf//:protobuf_headers", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc new file mode 100644 index 00000000000..893ff152f0c --- /dev/null +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/plugin/executor/compiler.h" +#include "tensorflow/compiler/plugin/executor/executable.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/status_macros.h" + +#include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/strcat.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; +namespace port = ::perftools::gputools::port; + +namespace xla { +namespace executorplugin { + +/* + * Run optimization passes on the module. The graph is transformed by + * each pass in the optimization pipeline. The service subdirectory + * contains useful optimization passes. 
+ */
+Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module,
+                                            HloDumper dump_hlo) {
+  HloPassPipeline pipeline("Executor", dump_hlo);
+  pipeline.AddPass<Inliner>();
+  pipeline.AddPass<HloSubcomputationUnification>();
+  pipeline.AddPass<HloCSE>(false);
+
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+      false, [](const Shape&, const Shape&) { return false; });
+  pipeline.AddPass<ReshapeMover>();
+  pipeline.AddPass<HloConstantFolding>();
+  pipeline.AddPass<HloCSE>(true);
+
+  pipeline.AddPass<HloDCE>();
+  pipeline.AddPass<FlattenCallGraph>();
+  return pipeline.Run(hlo_module).status();
+}
+
+StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
+    std::unique_ptr<HloModule> hlo_module, HloDumper dump_hlo,
+    se::StreamExecutor* stream_exec) {
+  TF_RET_CHECK(stream_exec != nullptr);
+
+  VLOG(1) << "Generate graph " << hlo_module->name();
+
+  TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo));
+
+  // Typically you would visit the HLO graph, building up a compiled
+  // equivalent. In this case we are using an HLO evaluator at execution time,
+  // so we don't need to compile anything.
+
+  // Create an executable from only the HLO module.
+  std::unique_ptr<Executable> executable;
+  executable.reset(new ExecutorExecutable(std::move(hlo_module)));
+
+  return std::move(executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
+    std::vector<std::unique_ptr<HloModule>> hlo_modules,
+    HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
+
+  return tensorflow::errors::Unimplemented(
+      "Compilation of multiple HLO modules is not supported on Executor.");
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ExecutorCompiler::CompileAheadOfTime(
+    std::vector<std::unique_ptr<HloModule>> hlo_modules,
+    HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
+
+  return tensorflow::errors::InvalidArgument(
+      "AOT compilation not supported on Executor");
+}
+
+se::Platform::Id ExecutorCompiler::PlatformId() const {
+  return sep::kExecutorPlatformId;
+}
+
+HloCostAnalysis::ShapeSizeFunction
+ExecutorCompiler::ShapeSizeBytesFunction() const {
+  return ExecutorExecutable::ShapeSizeBytes;
+}
+
+
+}  // namespace executorplugin
+}  // namespace xla
+
+REGISTER_MODULE_INITIALIZER(executor_compiler, {
+  xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
+    return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
+  });
+});
diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h
new file mode 100644
index 00000000000..8fe591c8abd
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/compiler.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ +#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ + +#include + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include "tensorflow/compiler/plugin/executor/platform_id.h" + +namespace xla { +namespace executorplugin { + +class ExecutorCompiler : public Compiler { + public: + ExecutorCompiler() {} + ~ExecutorCompiler() override {} + + StatusOr> Compile( + std::unique_ptr hlo_module, + HloDumper dump_hlo, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr>> Compile( + std::vector> hlo_module, + HloDumper dump_hlo, + std::vector stream_exec) override; + + StatusOr>> + CompileAheadOfTime( + std::vector> module, + HloDumper dump_hlo, const AotCompilationOptions& options) override; + + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; + + perftools::gputools::Platform::Id PlatformId() const override; + + private: + Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo); + + TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler); +}; + +} // namespace executorplugin +} // namespace xla + +#endif // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc new file mode 100644 index 00000000000..bbc39dc03f8 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/device.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +const char* const DEVICE_XLA_EXEC = "XLA_EXEC"; +const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT"; + +constexpr std::array kExecAllTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; + +class XlaExaDeviceFactory : public DeviceFactory { + public: + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector* devices) override; +}; + +Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options, + const string& name_prefix, + std::vector* devices) { + static XlaDeviceOpRegistrations* registrations = + RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT); + (void)registrations; + + std::unique_ptr device; + TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0, + DEVICE_EXEC_XLA_JIT, options, + name_prefix, &device)); + devices->push_back(device.release()); + return Status::OK(); +} + +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110); + +// Kernel registrations + +static bool OpFilter(KernelDef* kdef) { return true; } + +REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes); +REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes); +REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter); + +} // namespace tensorflow diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc new file mode 100644 index 00000000000..92a517ba533 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executable.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/plugin/executor/executable.h" +#include "tensorflow/compiler/plugin/executor/executor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; + +namespace xla { +namespace executorplugin { + +ExecutorExecutable::ExecutorExecutable(std::unique_ptr hlo_module) + : Executable(std::move(hlo_module), ShapeSizeBytes) {} + +ExecutorExecutable::~ExecutorExecutable() {} + +static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor, + const Literal& literal) { + int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); + void* buf = executor->Allocate(size); + const void* src = LiteralUtil::InternalData(literal); + memcpy(buf, src, size); + return se::DeviceMemoryBase(buf, size); +} + +static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor, + const Literal& literal) { + const Shape& shape = literal.shape(); + if (shape.element_type() != xla::TUPLE) { + return AllocateSingleOutput(executor, literal); + } else { + int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); + void** buf = reinterpret_cast(executor->Allocate(size)); + for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { + se::DeviceMemoryBase out = + AllocateSingleOutput(executor, literal.tuple_literals(n)); + *buf++ = out.opaque(); + } + + return se::DeviceMemoryBase(buf, size); + } +} + +StatusOr ExecutorExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + + VLOG(1) << "Execute " << module().name(); + if (VLOG_IS_ON(2)) { + for (const auto& a : arguments) { + VLOG(2) << "-- argument " << a.opaque(); + } + } + + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + + HloComputation* computation = module().entry_computation(); + if (computation->num_parameters() != arguments.size()) { + return tensorflow::errors::Internal( + "Mismatch between argument count and graph parameter count."); + } + + // Create the arguments as an vector of XLA literals + std::vector> arg_literals; + std::vector arg_literals_ptrs; + for (int64 p = 0; p < computation->num_parameters(); p++) { + // Create the input literal for the parameter + HloInstruction* param = computation->parameter_instruction(p); + arg_literals.emplace_back(LiteralUtil::CreateFromShape(param->shape())); + arg_literals_ptrs.push_back(arg_literals.back().get()); + + // Copy in the data from the stream_executor buffers + void* buffer = LiteralUtil::MutableInternalData(arg_literals.back().get()); + memcpy(buffer, arguments[p].opaque(), + ShapeUtil::ByteSizeOf(param->shape())); + } + + // Execute the graph using the evaluator + HloEvaluator evaluator; + std::unique_ptr output; + TF_ASSIGN_OR_RETURN(output, + evaluator.Evaluate(computation, arg_literals_ptrs)); + + // Copy the result into the return buffer + perftools::gputools::StreamExecutor* executor(stream->parent()); + sep::ExecutorExecutor* executorExecutor( + static_cast(executor->implementation())); + + se::DeviceMemoryBase ret = + AllocateOutputBuffer(executorExecutor, *(output.get())); + + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + + { + tensorflow::mutex_lock 
lock(mutex_); + const double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + } + + return ret; +} + +StatusOr> ExecutorExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + return tensorflow::errors::Unimplemented( + "ExecuteOnStream is not yet supported on Executor."); +} + +StatusOr ExecutorExecutable::ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments) { + return tensorflow::errors::Unimplemented( + "ExecuteAsyncOnStream is not yet supported on Executor."); +} + +/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape)) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); +} + + +} // namespace executorplugin +} // namespace xla diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h new file mode 100644 index 00000000000..ba3d4da21d0 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executable.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace executorplugin { + +class ExecutorExecutable : public Executable { + public: + ExecutorExecutable(std::unique_ptr hlo_module); + ~ExecutorExecutable() override; + + StatusOr ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr> ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) override; + + StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments) override; + + static int64 ShapeSizeBytes(const Shape& shape); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable); +}; + +} // namespace executorplugin +} // namespace xla + +#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc new file mode 100644 index 00000000000..e72c2711f79 --- /dev/null +++ b/tensorflow/compiler/plugin/executor/executor.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/plugin/executor/executor.h" +#include "tensorflow/compiler/plugin/executor/platform_id.h" + +#include "tensorflow/compiler/xla/status_macros.h" + +#include +#include + +namespace se = ::perftools::gputools; + +namespace perftools { +namespace gputools { +namespace executorplugin { + +host::HostStream *AsExecutorStream(Stream *stream) { + DCHECK(stream != nullptr); + return dynamic_cast(stream->implementation()); +} + +ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config) + : plugin_config_(plugin_config) {} + +ExecutorExecutor::~ExecutorExecutor() {} + +void *ExecutorExecutor::Allocate(uint64 size) { + void *buf = new char[size]; + return buf; +} + +void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, + uint64 offset_bytes, + uint64 size_bytes) { + return parent + offset_bytes; +} + +void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) { + if (!mem->is_sub_buffer()) { + delete[] static_cast(mem->opaque()); + } +} + +bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &dev_src, uint64 size) { + AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { + port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); + }); + return true; +} + +bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, + const void *host_src, uint64 size) { + AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { + port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); + }); + return true; +} + +port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, + const void *host_src, + uint64 size) { + memcpy(dev_dst->opaque(), host_src, size); + return port::Status::OK(); +} + +port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst, + const DeviceMemoryBase &dev_src, + uint64 size) { + memcpy(host_dst, dev_src.opaque(), size); + return port::Status::OK(); +} + +bool ExecutorExecutor::HostCallback(Stream *stream, + std::function callback) { + AsExecutorStream(stream)->EnqueueTask(callback); + return true; +} + +bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { + AsExecutorStream(dependent)->EnqueueTask( + [other]() { other->BlockHostUntilDone(); }); + AsExecutorStream(dependent)->BlockUntilDone(); + return true; +} + +bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) { + dynamic_cast(timer->implementation())->Start(stream); + return true; +} + +bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) { + dynamic_cast(timer->implementation())->Stop(stream); + return true; +} + +bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) { + AsExecutorStream(stream)->BlockUntilDone(); + return true; +} + +DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const { + internal::DeviceDescriptionBuilder builder; + + builder.set_device_address_bits(64); + + builder.set_name("Executor"); + builder.set_device_vendor("VectorName"); + builder.set_platform_version("1.0"); + builder.set_driver_version("1.0"); + builder.set_runtime_version("1.0"); + builder.set_pci_bus_id("1"); + builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); + builder.set_clock_rate_ghz(static_cast(CLOCKS_PER_SEC) / 1e9); + + auto built = builder.Build(); + return built.release(); +} + +} // namespace executorplugin +} // namespace gputools +} // namespace perftools diff --git 
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/plugin/executor/executor.h
new file mode 100644
index 00000000000..32fdb157e48
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.h
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include <list>
+#include <memory>
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+
+class ExecutorExecutor : public internal::StreamExecutorInterface {
+ public:
+  explicit ExecutorExecutor(const PluginConfig &plugin_config);
+  ~ExecutorExecutor() override;
+
+  port::Status Init(int device_ordinal,
+                    DeviceOptions device_options) override {
+    return port::Status::OK();
+  }
+
+  bool GetKernel(const MultiKernelLoaderSpec &spec,
+                 KernelBase *kernel) override {
+    return false;
+  }
+  bool Launch(Stream *stream, const ThreadDim &thread_dims,
+              const BlockDim &block_dims, const KernelBase &kernel,
+              const KernelArgsArrayBase &args) override {
+    return false;
+  }
+
+  void *Allocate(uint64 size) override;
+  void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+                          uint64 size_bytes) override;
+  void Deallocate(DeviceMemoryBase *mem) override;
+
+  void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
+  void HostMemoryDeallocate(void *mem) override {
+    delete[] static_cast<char *>(mem);
+  }
+  bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
+  bool HostMemoryUnregister(void *mem) override { return true; }
+
+  bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
+              uint64 size) override;
+  bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
+              uint64 size) override;
+  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
+                            const DeviceMemoryBase &host_src,
+                            uint64 size) override {
+    return false;
+  }
+
+  bool MemZero(Stream *stream, DeviceMemoryBase *location,
+               uint64 size) override {
+    return false;
+  }
+  bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+              uint64 size) override {
+    return false;
+  }
+  bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+                uint64 size) override {
+    return false;
+  }
+
+  // No "synchronize all activity" implemented for this platform at the moment.
+  bool SynchronizeAllActivity() override { return false; }
+  bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
+    return false;
+  }
+
+  bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+                         uint64 size) override {
+    return false;
+  }
+
+  port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
+                                 const void *host_src, uint64 size) override;
+  port::Status SynchronousMemcpy(void *host_dst,
+                                 const DeviceMemoryBase &pop_src,
+                                 uint64 size) override;
+  port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
+                                               const DeviceMemoryBase &pop_src,
+                                               uint64 size) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+  port::Status AllocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status DeallocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status RecordEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status WaitForEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  Event::Status PollForEventStatus(Event *event) override {
+    return Event::Status::kError;
+  }
+
+  bool AllocateStream(Stream *stream) override { return true; }
+  void DeallocateStream(Stream *stream) override {}
+  bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+  bool AllocateTimer(Timer *timer) override { return true; }
+  void DeallocateTimer(Timer *timer) override {}
+  bool StartTimer(Stream *stream, Timer *timer) override;
+  bool StopTimer(Stream *stream, Timer *timer) override;
+
+  bool BlockHostUntilDone(Stream *stream) override;
+
+  int PlatformDeviceCount() override { return 1; }
+
+  bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
+    return false;
+  }
+
+  DeviceDescription *PopulateDeviceDescription() const override;
+
+  port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return port::Status::OK();
+  }
+
+  bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return true;
+  }
+
+  SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+    return SharedMemoryConfig::kDefault;
+  }
+
+  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
+    return port::Status{port::error::UNIMPLEMENTED,
+                        "Shared memory not supported"};
+  }
+
+  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+      override {
+    return nullptr;
+  }
+
+  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+      override {
+    return nullptr;
+  }
+
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+      override {
+    return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+  }
+
+  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+    return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
+  }
+
+  port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+                                                Args args);
+
+ private:
+  DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
+
+  port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+      const xla::Shape &shape);
+
+  const PluginConfig plugin_config_;
+};
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
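
A pattern worth noting in this header and in `executor.cc`: the public `Stream` and `Timer` objects own an opaque `implementation()` pointer, and the plugin `dynamic_cast`s it back to the host-backed type it installed via `GetStreamImplementation`/`GetTimerImplementation` (see `AsExecutorStream` above). A standalone sketch of that wrap-and-downcast shape, with hypothetical names throughout:

```c++
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-ins for StreamInterface / HostStream / Stream.
struct StreamInterfaceSketch {
  virtual ~StreamInterfaceSketch() = default;
};

struct HostStreamSketch : StreamInterfaceSketch {
  int enqueued = 0;
  void EnqueueTask() { ++enqueued; }
};

class StreamSketch {
 public:
  explicit StreamSketch(std::unique_ptr<StreamInterfaceSketch> impl)
      : impl_(std::move(impl)) {}
  StreamInterfaceSketch *implementation() { return impl_.get(); }

 private:
  std::unique_ptr<StreamInterfaceSketch> impl_;
};

// Mirrors AsExecutorStream: recover the concrete host stream. The
// dynamic_cast is safe only because this platform always installs a
// HostStreamSketch via its GetStreamImplementation equivalent.
HostStreamSketch *AsHostStream(StreamSketch *stream) {
  auto *host = dynamic_cast<HostStreamSketch *>(stream->implementation());
  assert(host != nullptr);
  return host;
}

int main() {
  StreamSketch stream(std::make_unique<HostStreamSketch>());
  AsHostStream(&stream)->EnqueueTask();
}
```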
diff --git a/tensorflow/compiler/plugin/executor/platform.cc b/tensorflow/compiler/plugin/executor/platform.cc
new file mode 100644
index 00000000000..2f339f04a7b
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/platform.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+PLATFORM_DEFINE_ID(kExecutorPlatformId);
+
+ExecutorPlatform::ExecutorPlatform() : name_("Executor") {}
+
+ExecutorPlatform::~ExecutorPlatform() {}
+
+Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; }
+
+int ExecutorPlatform::VisibleDeviceCount() const { return 1; }
+
+const string& ExecutorPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
+    int ordinal) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  config.plugin_config = PluginConfig();
+  config.device_options = DeviceOptions::Default();
+  return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*>
+ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
+    int device_ordinal, const PluginConfig& plugin_config) {
+  StreamExecutorConfig config;
+  config.ordinal = device_ordinal;
+  config.plugin_config = plugin_config;
+  config.device_options = DeviceOptions::Default();
+  return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
+    const StreamExecutorConfig& config) {
+  mutex_lock lock(executors_mutex_);
+
+  port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+  if (status.ok()) {
+    return status.ValueOrDie();
+  }
+
+  port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+      GetUncachedExecutor(config);
+  if (!executor.ok()) {
+    return executor.status();
+  }
+
+  StreamExecutor* naked_executor = executor.ValueOrDie().get();
+  SE_RETURN_IF_ERROR(
+      executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
+  return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+  auto executor = port::MakeUnique<StreamExecutor>(
+      this, port::MakeUnique<ExecutorExecutor>(config.plugin_config));
+  auto init_status = executor->Init(config.ordinal, config.device_options);
+  if (!init_status.ok()) {
+    return port::Status{
+        port::error::INTERNAL,
+        port::Printf(
+            "failed initializing StreamExecutor for device ordinal %d: %s",
+            config.ordinal, init_status.ToString().c_str())};
+  }
+
+  return std::move(executor);
+}
+
+void ExecutorPlatform::RegisterTraceListener(
+    std::unique_ptr<TraceListener> listener) {
+  LOG(FATAL) << "not yet implemented: register executor trace listener";
+}
+
+void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) {
+  LOG(FATAL) << "not yet implemented: unregister executor trace listener";
+}
+
+static void InitializeExecutorPlatform() {
+  std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform);
+  SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(executor_platform,
+                            sep::InitializeExecutorPlatform());
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform,
+                                     multi_platform_manager);
diff --git a/tensorflow/compiler/plugin/executor/platform.h b/tensorflow/compiler/plugin/executor/platform.h
new file mode 100644
index 00000000000..c252a589d49
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.h
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+class ExecutorPlatform : public Platform {
+ public:
+  ExecutorPlatform();
+  ~ExecutorPlatform() override;
+
+  Platform::Id id() const override;
+
+  // Device count is less clear-cut for CPUs than accelerators; this platform
+  // currently exposes a single executor device.
+  int VisibleDeviceCount() const override;
+
+  const string& Name() const override;
+
+  port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+
+  port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+      int ordinal, const PluginConfig& config) override;
+
+  port::StatusOr<StreamExecutor*> GetExecutor(
+      const StreamExecutorConfig& config) override;
+
+  port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config) override;
+
+  void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;
+
+  void UnregisterTraceListener(TraceListener* listener) override;
+
+ private:
+  // This platform's name.
+  string name_;
+
+  // Mutex that guards the ordinal-to-executor map.
+  mutable mutex executors_mutex_;
+
+  // Cache of created StreamExecutors.
+  ExecutorCache executor_cache_;
+
+  SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform);
+};
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
diff --git a/tensorflow/compiler/plugin/executor/platform_id.h b/tensorflow/compiler/plugin/executor/platform_id.h
new file mode 100644
index 00000000000..8d2b29a3e4e
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform_id.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+
+#include "tensorflow/stream_executor/platform.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+extern const Platform::Id kExecutorPlatformId;
+
+}  // namespace executorplugin
+}  // namespace gputools
+}  // namespace perftools
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
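
`platform.cc` and (below) `transfer_manager.cc` both hook themselves into the runtime with static initializers: a factory lands in a global registry at load time (`REGISTER_MODULE_INITIALIZER`, `InitModule`), so no central list has to be edited to add a backend. A minimal standalone sketch of that registration idiom, with illustrative names only:

```c++
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct PlatformSketch {
  virtual ~PlatformSketch() = default;
  virtual std::string Name() const = 0;
};

// Global registry keyed by platform name, loosely mirroring
// MultiPlatformManager. Heap-allocated to dodge destruction-order issues.
using Factory = std::function<std::unique_ptr<PlatformSketch>()>;
std::map<std::string, Factory>& Registry() {
  static auto* registry = new std::map<std::string, Factory>();
  return *registry;
}

struct ExecutorPlatformSketch : PlatformSketch {
  std::string Name() const override { return "Executor"; }
};

// Runs before main() via static initialization, like InitModule() above.
static bool registered = [] {
  Registry()["Executor"] = [] {
    return std::unique_ptr<PlatformSketch>(new ExecutorPlatformSketch());
  };
  return true;
}();

int main() {
  auto platform = Registry()["Executor"]();
  std::cout << platform->Name() << "\n";  // prints "Executor"
}
```

As the comment at the end of `platform.cc` notes, ordering between such initializers is not enforced in the open-source build, so registries written this way must tolerate being populated in any order.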
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
new file mode 100644
index 00000000000..b59d20a7791
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -0,0 +1,182 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorTransferManager::ExecutorTransferManager() {}
+
+se::Platform::Id ExecutorTransferManager::PlatformId() const {
+  return se::executorplugin::kExecutorPlatformId;
+}
+
+Status ExecutorTransferManager::TransferLiteralFromDevice(
+    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+    const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+  TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+  // Tuples are a special case: they contain one or more shapes nested to
+  // an arbitrary depth.
+  if (device_shape.element_type() == TUPLE) {
+    *literal->mutable_shape() = literal_shape;
+    TF_ASSIGN_OR_RETURN(
+        std::vector<se::DeviceMemoryBase> element_buffers,
+        ShallowCopyTupleFromDevice(executor, source, device_shape));
+    TF_RET_CHECK(element_buffers.size() ==
+                 ShapeUtil::TupleElementCount(device_shape));
+    for (int64 i = 0; i < element_buffers.size(); ++i) {
+      const Shape& element_device_shape = device_shape.tuple_shapes(i);
+      const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+      Literal* element_literal = literal->add_tuple_literals();
+      // Recursively call TransferLiteralFromDevice to copy over the data in
+      // the element array.
+      TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+          executor, element_buffers[i], element_device_shape,
+          element_literal_shape, element_literal));
+    }
+    return Status::OK();
+  }
+
+  *literal->mutable_shape() = device_shape;
+  LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal);
+  TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+      executor, source, ShapeUtil::ByteSizeOf(device_shape),
+      LiteralUtil::MutableInternalData(literal)));
+  if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+    literal->Swap(
+        LiteralUtil::Relayout(*literal, literal_shape.layout()).get());
+  }
+  TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+  return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+ExecutorTransferManager::ShallowCopyTupleFromDevice(
+    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+    const Shape& shape) {
+  TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+  std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+                                      nullptr);
+  int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+  auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+                                                    element_pointers.data());
+  if (!copy_status.ok()) {
+    return AddStatus(
+        Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+               copy_status.error_message()),
+        "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+  }
+
+  // Create a DeviceMemoryBase from each void* pointer.
+  std::vector<se::DeviceMemoryBase> destination;
+  for (int i = 0; i < element_pointers.size(); ++i) {
+    if (element_pointers[i] == nullptr &&
+        !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+      return FailedPrecondition("tuple contains nullptr at element %d", i);
+    }
+    int64 buffer_size =
+        ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
+    destination.emplace_back(element_pointers[i], buffer_size);
+  }
+  return std::move(destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToDevice(
+    se::StreamExecutor* executor, const Literal& literal,
+    se::DeviceMemoryBase* destination) {
+  const Shape& shape = literal.shape();
+
+  if (ShapeUtil::IsTuple(literal.shape())) {
+    std::vector<void*> tuple_elements_on_device;
+    for (const Literal& tuple_element : literal.tuple_literals()) {
+      se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+          GetByteSizeRequirement(tuple_element.shape()));
+      TF_RETURN_IF_ERROR(
+          TransferLiteralToDevice(executor, tuple_element, &allocation));
+      tuple_elements_on_device.push_back(allocation.opaque());
+    }
+    return TransferBufferToDevice(
+        executor, tuple_elements_on_device.size() * sizeof(void*),
+        tuple_elements_on_device.data(), destination);
+  }
+
+  return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+                                LiteralUtil::InternalData(literal),
+                                destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToInfeed(
+    se::StreamExecutor* executor, const Literal& literal) {
+  const Shape& shape = literal.shape();
+  VLOG(1) << "transferring literal shape to infeed: "
+          << ShapeUtil::HumanString(shape);
+
+  return Status::OK();
+}
+
+Status ExecutorTransferManager::TransferLiteralFromOutfeed(
+    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+    Literal* literal) {
+  const Shape& shape = literal->shape();
+  VLOG(1) << "transferring literal shape from outfeed: "
+          << ShapeUtil::HumanString(shape);
+
+  return Status::OK();
+}
+
+Status ExecutorTransferManager::ResetDevices(
+    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+        executors) {
+  return Unimplemented("Device reset not supported");
+}
+
+int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
+  return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+}  // namespace executorplugin
+}  // namespace xla
+
+static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
+  return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
+}
+
+static bool InitModule() {
+  xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
+                                                &CreateExecutorTransferManager);
+  return true;
+}
+static bool module_initialized = InitModule();
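
`ShallowCopyTupleFromDevice` above depends on how XLA lays out a tuple in device memory: a flat array of per-element buffer pointers, which for this host-backed platform is ordinary host memory. A standalone sketch of reading such a pointer table and wrapping each entry with its byte size, using hypothetical types (`BufferRefSketch` stands in for `se::DeviceMemoryBase`):

```c++
#include <cstdint>
#include <cstring>
#include <vector>

struct BufferRefSketch {  // stand-in for se::DeviceMemoryBase
  void* ptr;
  uint64_t size;
};

// Given the tuple's backing store (an array of element pointers) and the
// byte size of each element, produce shallow references without copying
// the element data itself; this is the essence of ShallowCopyTupleFromDevice.
std::vector<BufferRefSketch> ShallowCopyTuple(
    const void* tuple_storage, const std::vector<uint64_t>& element_sizes) {
  std::vector<void*> pointers(element_sizes.size());
  std::memcpy(pointers.data(), tuple_storage,
              pointers.size() * sizeof(void*));
  std::vector<BufferRefSketch> refs;
  refs.reserve(pointers.size());
  for (size_t i = 0; i < pointers.size(); ++i) {
    refs.push_back({pointers[i], element_sizes[i]});
  }
  return refs;
}

int main() {
  float a[2] = {1.f, 2.f};
  int b[3] = {3, 4, 5};
  void* tuple[2] = {a, b};  // "device" tuple = array of element pointers
  auto refs = ShallowCopyTuple(tuple, {sizeof(a), sizeof(b)});
  return refs[0].ptr == a && refs[1].ptr == b ? 0 : 1;
}
```

The real code additionally validates null entries against zero-element shapes and computes sizes with `ShapeUtil::ByteSizeOf`, as shown in the file above.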
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h
new file mode 100644
index 00000000000..22142cd778a
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+#include <vector>
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorTransferManager : public TransferManager {
+ public:
+  ExecutorTransferManager();
+
+  ~ExecutorTransferManager() override {}
+
+  se::Platform::Id PlatformId() const override;
+
+  StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
+      se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+      const Shape& shape) override;
+
+  Status TransferLiteralFromDevice(se::StreamExecutor* executor,
+                                   const se::DeviceMemoryBase& source,
+                                   const Shape& device_shape,
+                                   const Shape& literal_shape,
+                                   Literal* literal) override;
+
+  Status TransferLiteralToDevice(se::StreamExecutor* executor,
+                                 const Literal& literal,
+                                 se::DeviceMemoryBase* destination) override;
+
+  Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+                                 const Literal& literal) override;
+
+  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+                                    const Shape& literal_shape,
+                                    Literal* literal) override;
+
+  Status ResetDevices(
+      tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+
+  int64 GetByteSizeRequirement(const Shape& shape) override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
+};
+
+}  // namespace executorplugin
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index a3c634c1abf..4bbb2767ac0 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1,11 +1,23 @@
 licenses(["notice"])  # Apache 2.0
 
-package(
-    default_visibility = [
+package_group(
+    name = "internal",
+    includes = [
         "//tensorflow/compiler/tf2xla:internal",
     ],
 )
 
+package_group(
+    name = "friends",
+    includes = [
+        "//tensorflow/compiler/tf2xla:friends",
+    ],
+)
+
+package(
+    default_visibility = [":internal"],
+)
+
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
@@ -19,6 +31,7 @@ py_library(
     testonly = 1,
     srcs = ["xla_test.py"],
     srcs_version = "PY2AND3",
+    visibility = [":friends"],
     deps = [
         "//tensorflow/contrib/compiler:compiler_py",
         "//tensorflow/core:protos_all_py",
@@ -38,6 +51,34 @@ cc_library(
     deps = ["//tensorflow/core:framework_lite"],
 )
 
+tf_xla_py_test(
+    name = "adagrad_test",
+    size = "small",
+    srcs = ["adagrad_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+    ],
+)
+
+tf_xla_py_test(
+    name = "adam_test",
+    size = "small",
+    srcs = ["adam_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+    ],
+)
+
 tf_xla_py_test(
     name =
"binary_ops_test", size = "small", @@ -100,6 +141,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "conv3d_test", + size = "medium", + srcs = ["conv3d_test.py"], + shard_count = 5, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -113,6 +170,33 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "slice_ops_test", + size = "small", + srcs = ["slice_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "ftrl_test", + size = "small", + srcs = ["ftrl_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "function_test", size = "small", @@ -139,6 +223,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "momentum_test", + size = "small", + srcs = ["momentum_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "nary_ops_test", size = "small", @@ -179,6 +277,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "pooling_ops_3d_test", + size = "medium", + srcs = ["pooling_ops_3d_test.py"], + shard_count = 10, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "random_ops_test", size = "small", @@ -208,6 +321,64 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reverse_ops_test", + size = "small", + srcs = ["reverse_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +tf_xla_py_test( + name = "rmsprop_test", + size = "small", + srcs = ["rmsprop_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "spacetobatch_op_test", + size = "medium", + srcs = ["spacetobatch_op_test.py"], + shard_count = 3, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "tensor_array_ops_test", + size = "small", + srcs = ["tensor_array_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_grad", + "//tensorflow/python:tensor_array_ops", + 
"//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", @@ -236,6 +407,23 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "variable_ops_test", + size = "small", + srcs = ["variable_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + cuda_py_test( name = "xla_device_test", size = "small", @@ -294,7 +482,6 @@ tf_cuda_cc_test( # This test is randomized, so only run it if explicitly requested. tags = [ "manual", - "noguitar", "notap", ], deps = [":randomized_tests_library"], @@ -336,8 +523,12 @@ cuda_py_test( # --dump_graph_dir, and the config file was written by hand. # # Run the following to build a minimal benchmark of the computation on Android: -# $ bazel build -c opt --config=android_arm \ -# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark +# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \ +# --cpu=armeabi-v7a \ +# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ +# --crosstool_top=//external:android/crosstool \ +# //tensorflow/compiler/tests:lstm_layer_inference_benchmark + # # Currently the resulting binary size is ~190KB tf_library( diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py new file mode 100644 index 00000000000..a5c5885b428 --- /dev/null +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -0,0 +1,116 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Adagrad.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad + + +class AdagradOptimizerTest(XLATestCase): + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + ada_opt = adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1) + ada_update = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 3 steps of adagrad + for _ in range(3): + ada_update.run() + # Validate updated params + self.assertAllCloseAccordingToType( + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + ada_opt = adagrad.AdagradOptimizer( + constant_op.constant(3.0), initial_accumulator_value=0.1) + ada_update = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 3 steps of adagrad + for _ in range(3): + ada_update.run() + # Validate updated params + self.assertAllCloseAccordingToType( + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + + def testSharing(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + ada_opt = adagrad.AdagradOptimizer(3.0) + # Apply the optimizer twice. Both applications will use + # the same accums. 
+ ada_update1 = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + ada_update2 = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.assertEqual(["accumulator"], ada_opt.get_slot_names()) + slot0 = ada_opt.get_slot(var0, "accumulator") + self.assertEquals(slot0.get_shape(), var0.get_shape()) + slot1 = ada_opt.get_slot(var1, "accumulator") + self.assertEquals(slot1.get_shape(), var1.get_shape()) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values. + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Mix the first and the second adagrad for 3 steps. + ada_update1.run() + ada_update2.run() + ada_update1.run() + # Validate updated params (the same as with only 1 Adagrad). + self.assertAllCloseAccordingToType( + np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([2.715679168701172, 3.715679168701172]), var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py new file mode 100644 index 00000000000..3215dc36e5b --- /dev/null +++ b/tensorflow/compiler/tests/adam_test.py @@ -0,0 +1,176 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Adam.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(XLATestCase): + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + else: + update2.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 852c80db1fe..7221a0a3c74 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -107,6 +107,12 @@ class BinaryOpsTest(XLATestCase): np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops._elu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), + expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), @@ -132,6 +138,20 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + self._testBinary( + gen_nn_ops._sparse_softmax_cross_entropy_with_logits, + np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2]], dtype=dtype), + np.array([2, 1, 7], dtype=np.int32), + expected=[ + np.array([1.342536, 1.442536, np.nan], dtype=dtype), + np.array([[0.213838, 0.236328, -0.738817, 0.288651], + [0.213838, -0.763672, 0.261183, 0.288651], + [np.nan, np.nan, np.nan, np.nan]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) + def testIntOps(self): for dtype in self.int_types: self._testBinary( diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 820db13d0b1..0bde616521a 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,12 +1,14 @@ """Build rules for Tensorflow/XLA testing.""" load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") +load("//tensorflow/compiler/tests:plugin.bzl", "plugins") def all_backends(): + b = ["cpu"] + plugins.keys() if cuda_is_configured(): - return ["cpu", "gpu"] + return b + ["gpu"] else: - return ["cpu"] + return b def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, disabled_backends=None, **kwargs): @@ -53,6 
+55,10 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, backend_args += ["--test_device=XLA_GPU", "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] backend_tags += ["requires-gpu-sm35"] + elif backend in plugins: + backend_args += ["--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"]] + backend_tags += plugins[backend]["tags"] else: fail("Unknown backend {}".format(backend)) diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 01cfbd9f7c0..4bc118b5bdb 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -310,7 +310,7 @@ class Conv2DBackpropFilterTest(XLATestCase): data_format="NHWC") value = sess.run(tensor, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-5) + self.assertArrayNear(expected, np.ravel(value), 1e-3) def testConv2D1x1Filter(self): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py new file mode 100644 index 00000000000..3bebf46511c --- /dev/null +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -0,0 +1,233 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for 3D convolutions using the XLA JIT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import nn_ops +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +from tensorflow.python.platform import googletest + + +# Test cloned from +# tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py +class Conv3DBackpropFilterV2GradTest(XLATestCase): + + def testGradient(self): + with self.test_session(), self.test_scope(): + for padding in ["SAME", "VALID"]: + for stride in [1, 2]: + np.random.seed(1) + in_shape = [2, 4, 3, 3, 2] + in_val = constant_op.constant( + 2 * np.random.random_sample(in_shape) - 1, dtype=dtypes.float32) + filter_shape = [3, 3, 3, 2, 3] + strides = [1, stride, stride, stride, 1] + # Make a convolution op with the current settings, just to easily get + # the shape of the output. 
+ conv_out = nn_ops.conv3d(in_val, + array_ops.zeros(filter_shape), strides, + padding) + out_backprop_shape = conv_out.get_shape().as_list() + out_backprop_val = constant_op.constant( + 2 * np.random.random_sample(out_backprop_shape) - 1, + dtype=dtypes.float32) + output = nn_ops.conv3d_backprop_filter_v2(in_val, filter_shape, + out_backprop_val, strides, + padding) + err = gradient_checker.compute_gradient_error( + [in_val, out_backprop_val], [in_shape, out_backprop_shape], + output, filter_shape) + print("conv3d_backprop_filter gradient err = %g " % err) + err_tolerance = 1e-3 + self.assertLess(err, err_tolerance) + + +# Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py +class Conv3DTransposeTest(XLATestCase): + + def testConv3DTransposeSingleStride(self): + with self.test_session(), self.test_scope(): + strides = [1, 1, 1, 1, 1] + + # Input, output: [batch, depth, height, width, channel] + x_shape = [2, 5, 6, 4, 3] + y_shape = [2, 5, 6, 4, 2] + + # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth] + f_shape = [3, 3, 3, 2, 3] + + x = constant_op.constant( + 1.0, shape=x_shape, name="x", dtype=dtypes.float32) + f = constant_op.constant( + 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) + output = nn_ops.conv3d_transpose( + x, f, y_shape, strides=strides, padding="SAME") + value = output.eval() + + # We count the number of cells being added at the locations in the output. + # At the center, #cells = kernel_depth * kernel_height * kernel_width + # At the corners, #cells = ceil(kernel_depth/2) * ceil(kernel_height/2) + # * ceil(kernel_width/2) + # At the edges, #cells = + # kernel_depth * ceil(kernel_height/2) * ceil(kernel_width/2) or + # ceil(kernel_depth/2) * kernel_height * ceil(kernel_width/2) or + # ceil(kernel_depth/2) * ceil(kernel_height/2) * kernel_width + # At the borders, #cells = + # ceil(kernel_depth/2) * kernel_height * kernel_width or + # kernel_depth * ceil(kernel_height/2) * kernel_width or + # kernel_depth * kernel_height * ceil(kernel_width/2) + + for n in xrange(x_shape[0]): + for k in xrange(f_shape[3]): + for w in xrange(y_shape[3]): + for h in xrange(y_shape[2]): + for d in xrange(y_shape[1]): + d_in = d > 0 and d < y_shape[1] - 1 + h_in = h > 0 and h < y_shape[2] - 1 + w_in = w > 0 and w < y_shape[3] - 1 + if d_in + h_in + w_in == 3: + target = 27 * 3.0 + elif d_in + h_in + w_in == 2: + target = 18 * 3.0 + elif d_in or h_in or w_in: + target = 12 * 3.0 + else: + target = 8 * 3.0 + self.assertAllClose(target, value[n, d, h, w, k]) + + def testConv3DTransposeSame(self): + with self.test_session(), self.test_scope(): + strides = [1, 2, 2, 2, 1] + + # Input, output: [batch, depth, height, width, depth] + x_shape = [2, 5, 6, 4, 3] + y_shape = [2, 10, 12, 8, 2] + + # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth] + f_shape = [3, 3, 3, 2, 3] + + x = constant_op.constant( + 1.0, shape=x_shape, name="x", dtype=dtypes.float32) + f = constant_op.constant( + 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) + output = nn_ops.conv3d_transpose( + x, f, y_shape, strides=strides, padding="SAME") + value = output.eval() + + for n in xrange(x_shape[0]): + for k in xrange(f_shape[3]): + for w in xrange(y_shape[3]): + for h in xrange(y_shape[2]): + for d in xrange(y_shape[1]): + # We add a case for locations divisible by the stride. 
+ d_in = d % strides[1] == 0 and 0 < d < y_shape[1] - 1 + h_in = h % strides[2] == 0 and 0 < h < y_shape[2] - 1 + w_in = w % strides[3] == 0 and 0 < w < y_shape[3] - 1 + if d_in + h_in + w_in == 3: + target = 8 * 3.0 + elif d_in + h_in + w_in == 2: + target = 4 * 3.0 + elif d_in or h_in or w_in: + target = 2 * 3.0 + else: + target = 3.0 + self.assertAllClose(target, value[n, d, h, w, k]) + + def testConv3DTransposeValid(self): + with self.test_session(), self.test_scope(): + strides = [1, 2, 2, 2, 1] + + # Input, output: [batch, depth, height, width, depth] + x_shape = [2, 5, 6, 4, 3] + y_shape = [2, 11, 13, 9, 2] + + # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth] + f_shape = [3, 3, 3, 2, 3] + + x = constant_op.constant( + 1.0, shape=x_shape, name="x", dtype=dtypes.float32) + f = constant_op.constant( + 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) + output = nn_ops.conv3d_transpose( + x, f, y_shape, strides=strides, padding="VALID") + value = output.eval() + + cache_values = np.zeros(y_shape, dtype=np.float32) + + # The amount of padding added + pad = 1 + + for n in xrange(x_shape[0]): + for k in xrange(f_shape[3]): + for w in xrange(y_shape[3]): + for h in xrange(y_shape[2]): + for d in xrange(y_shape[1]): + # We add a case for locations divisible by the stride. + d_in = d % strides[1] == 0 and pad < d < y_shape[1] - 1 - pad + h_in = h % strides[2] == 0 and pad < h < y_shape[2] - 1 - pad + w_in = w % strides[3] == 0 and pad < w < y_shape[3] - 1 - pad + if d_in + h_in + w_in == 3: + target = 8 * 3.0 + elif d_in + h_in + w_in == 2: + target = 4 * 3.0 + elif d_in or h_in or w_in: + target = 2 * 3.0 + else: + target = 3.0 + cache_values[n, d, h, w, k] = target + + # copy values in the border + cache_values[n, :, :, 0, k] = cache_values[n, :, :, 1, k] + cache_values[n, :, :, -1, k] = cache_values[n, :, :, -2, k] + cache_values[n, :, 0, :, k] = cache_values[n, :, 1, :, k] + cache_values[n, :, -1, :, k] = cache_values[n, :, -2, :, k] + cache_values[n, 0, :, :, k] = cache_values[n, 1, :, :, k] + cache_values[n, -1, :, :, k] = cache_values[n, -2, :, :, k] + + self.assertAllClose(cache_values, value) + + def testGradient(self): + x_shape = [2, 3, 4, 3, 2] + f_shape = [3, 3, 3, 2, 2] + y_shape = [2, 6, 8, 6, 2] + strides = [1, 2, 2, 2, 1] + np.random.seed(1) # Make it reproducible. + x_val = np.random.random_sample(x_shape).astype(np.float64) + f_val = np.random.random_sample(f_shape).astype(np.float64) + with self.test_session(), self.test_scope(): + x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) + f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) + output = nn_ops.conv3d_transpose( + x, f, y_shape, strides=strides, padding="SAME") + err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape], + output, y_shape) + print("conv3d_transpose gradient err = %g " % err) + err_tolerance = 0.0005 + self.assertLess(err, err_tolerance) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py new file mode 100644 index 00000000000..6b328fb618b --- /dev/null +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -0,0 +1,253 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Ftrl optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad +from tensorflow.python.training import ftrl +from tensorflow.python.training import gradient_descent + + +class FtrlOptimizerTest(XLATestCase): + + def initVariableAndGradient(self, dtype): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.02, 0.04], dtype=dtype) + + return var0, var1, grads0, grads1 + + def equivAdagradTest_FtrlPart(self, steps, dtype): + var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + learning_rate_power=-0.5, # using Adagrad learning rate + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run Ftrl for a few steps + for _ in range(steps): + ftrl_update.run() + + return var0.eval(), var1.eval() + + def equivAdagradTest_AdagradPart(self, steps, dtype): + var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) + opt = adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1) + adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run Adagrad for a few steps + for _ in range(steps): + adagrad_update.run() + + return var0.eval(), var1.eval() + + def equivGradientDescentTest_FtrlPart(self, steps, dtype): + var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + learning_rate_power=-0.0, # using Fixed learning rate + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run Ftrl for a few steps + for _ in range(steps): + ftrl_update.run() + + return var0.eval(), var1.eval() + + def equivGradientDescentTest_GradientDescentPart(self, steps, dtype): + var0, var1, grads0, grads1 = 
self.initVariableAndGradient(dtype) + opt = gradient_descent.GradientDescentOptimizer(3.0, name="sgd") + sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run GradientDescent for a few steps + for _ in range(steps): + sgd_update.run() + + return var0.eval(), var1.eval() + + def testFtrlwithoutRegularization(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps FTRL + for _ in range(3): + ftrl_update.run() + + # Validate updated params + self.assertAllCloseAccordingToType( + np.array([-2.60260963, -4.29698515]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.28432083, -0.56694895]), var1.eval()) + + def testFtrlwithoutRegularization2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps FTRL + for _ in range(3): + ftrl_update.run() + + # Validate updated params + self.assertAllClose( + np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) + self.assertAllClose( + np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) + + def testFtrlWithL1(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + ftrl_update.run() + + # Validate updated params + self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval()) + 
self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval()) + + def testFtrlWithL1_L2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + ftrl_update.run() + + # Validate updated params + self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval()) + self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval()) + + # When variables are initialized to zero, FTRL-Proximal has two properties: + # 1. Without L1&L2 but with a fixed learning rate, FTRL-Proximal is identical + # to GradientDescent. + # 2. Without L1&L2 but with an adaptive learning rate, FTRL-Proximal is identical + # to Adagrad. + # (The helpers above select these regimes via learning_rate_power: -0.5 gives + # the Adagrad-style adaptive rate and 0.0 a fixed rate.) + # So, based on these two properties, we test whether our implementation of + # FTRL-Proximal performs the same updates as Adagrad or GradientDescent. + def testEquivAdagradwithoutRegularization(self): + steps = 5 + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype) + with self.test_session(), self.test_scope(): + val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + def testEquivGradientDescentwithoutRegularization(self): + steps = 5 + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype) + with self.test_session(), self.test_scope(): + val2, val3 = self.equivGradientDescentTest_GradientDescentPart( + steps, dtype) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 40cc7a5d600..cbe2888696c 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -103,7 +103,8 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) - def testFunctionsNoInline(self): + # TODO(b/36139787): Re-enable this test when noinline works again. + def DISABLED_testFunctionsNoInline(self): @function.Defun(dtypes.float32, noinline=True) def TimesTwo(x): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 8a568d6d58d..11914080ecc 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -160,12 +160,14 @@ class JitLaunchTest(test.TestCase): # function (say, Bar) which is not inlined. When the compiler compiles # Foo, it needs to symbolic execute Bar correctly regardless whether # Bar is inlined or not. - # + + # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True.
- self._compare( - AddOnceReturnTwice, [np.array( - [[[0.5, -1.0]]], dtype=np.float32)], - noinline=True) + # self._compare( + # AddOnceReturnTwice, [np.array( + # [[[0.5, -1.0]]], dtype=np.float32)], + # noinline=True) + + # Tests compiled=True and noinline=False. self._compare( AddOnceReturnTwice, [np.array( diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py new file mode 100644 index 00000000000..c00e3035a09 --- /dev/null +++ b/tensorflow/compiler/tests/momentum_test.py @@ -0,0 +1,179 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Momentum.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import momentum as momentum_lib + + +class MomentumOptimizerTest(XLATestCase): + + def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): + var += accum * lr * momentum + accum = accum * momentum + g + var -= lr * accum + var -= accum * lr * momentum + return var, accum + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + mom_opt = momentum_lib.MomentumOptimizer( + learning_rate=2.0, momentum=0.9) + mom_update = mom_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Check we have slots + self.assertEqual(["momentum"], mom_opt.get_slot_names()) + slot0 = mom_opt.get_slot(var0, "momentum") + self.assertEquals(slot0.get_shape(), var0.get_shape()) + self.assertFalse(slot0 in variables.trainable_variables()) + slot1 = mom_opt.get_slot(var1, "momentum") + self.assertEquals(slot1.get_shape(), var1.get_shape()) + self.assertFalse(slot1 in variables.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the momentum accumulators were 0. So we should see a normal + # update: v -= grad * learning_rate + mom_update.run() + # Check that the momentum accumulators have been updated.
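+ # (With momentum 0.9 and zero-initialized slots, the first update gives + # accum = 0.9 * 0 + grad, i.e. 0.1 for var0 and 0.01 for var1, which is + # what the checks below expect.)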
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) + self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + # Check that the parameters have been updated. + self.assertAllCloseAccordingToType( + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + # Step 2: the momentum accumulators contain the previous update. + mom_update.run() + # Check that the momentum accumulators have been updated. + self.assertAllCloseAccordingToType( + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + self.assertAllCloseAccordingToType( + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + # Check that the parameters have been updated. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), + 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( + (0.9 * 0.01 + 0.01) * 2.0) + ]), var1.eval()) + + def testNesterovMomentum(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + var0_np = np.array([1.0, 2.0], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + accum0_np = np.array([0.0, 0.0], dtype=dtype) + accum1_np = np.array([0.0, 0.0], dtype=dtype) + cost = 5 * var0 * var0 + 3 * var1 + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int32), name="global_step") + mom_op = momentum_lib.MomentumOptimizer( + learning_rate=2.0, momentum=0.9, use_nesterov=True) + opt_op = mom_op.minimize(cost, global_step, [var0, var1]) + variables.global_variables_initializer().run() + for _ in range(1, 5): + opt_op.run() + var0_np, accum0_np = self._update_nesterov_momentum_numpy( + var0_np, accum0_np, var0_np * 10, 2.0, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, + accum1_np, + 3, 2.0, 0.9) + self.assertAllClose(var0_np, var0.eval()) + self.assertAllClose(var1_np, var1.eval()) + + def testTensorLearningRateAndMomentum(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + mom_opt = momentum_lib.MomentumOptimizer( + learning_rate=constant_op.constant(2.0), + momentum=constant_op.constant(0.9)) + mom_update = mom_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Check we have slots + self.assertEqual(["momentum"], mom_opt.get_slot_names()) + slot0 = mom_opt.get_slot(var0, "momentum") + self.assertEquals(slot0.get_shape(), var0.get_shape()) + self.assertFalse(slot0 in variables.trainable_variables()) + slot1 = mom_opt.get_slot(var1, "momentum") + self.assertEquals(slot1.get_shape(), var1.get_shape()) + self.assertFalse(slot1 in variables.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the momentum accumulators were 0.
So we should see a normal + # update: v -= grad * learning_rate + mom_update.run() + # Check that the momentum accumulators have been updated. + self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) + self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + # Check that the parameters have been updated. + self.assertAllCloseAccordingToType( + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + # Step 2: the momentum accumulators contain the previous update. + mom_update.run() + # Check that the momentum accumulators have been updated. + self.assertAllCloseAccordingToType( + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + self.assertAllCloseAccordingToType( + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + # Check that the parameters have been updated. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), + 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( + (0.9 * 0.01 + 0.01) * 2.0) + ]), var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index d94e11b0789..2660e1d5728 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -75,6 +75,28 @@ class NAryOpsTest(XLATestCase): expected=np.array( [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) + def testOneHot(self): + with self.test_session() as session, self.test_scope(): + indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) + op = array_ops.one_hot(indices, + np.int32(4), + on_value=np.float32(7), off_value=np.float32(3)) + output = session.run(op) + expected = np.array([[[3, 3, 7, 3], [3, 3, 3, 7]], + [[7, 3, 3, 3], [3, 7, 3, 3]]], + dtype=np.float32) + self.assertAllEqual(output, expected) + + op = array_ops.one_hot(indices, + np.int32(4), + on_value=np.int32(2), off_value=np.int32(1), + axis=1) + output = session.run(op) + expected = np.array([[[1, 1], [1, 1], [2, 1], [1, 2]], + [[2, 1], [1, 2], [1, 1], [1, 1]]], + dtype=np.int32) + self.assertAllEqual(output, expected) + def testSplitV(self): with self.test_session() as session: with self.test_scope(): @@ -94,12 +116,14 @@ class NAryOpsTest(XLATestCase): np.array([1, 1], dtype=np.int32)], expected=np.array([[], []], dtype=np.float32)) - self._testNAry(lambda x: array_ops.strided_slice(*x), - [np.array([[], [], []], dtype=np.float32), - np.array([1, 0], dtype=np.int64), - np.array([3, 0], dtype=np.int64), - np.array([1, 1], dtype=np.int64)], - expected=np.array([[], []], dtype=np.float32)) + if np.int64 in self.int_types: + self._testNAry( + lambda x: array_ops.strided_slice(*x), [ + np.array([[], [], []], dtype=np.float32), np.array( + [1, 0], dtype=np.int64), np.array([3, 0], dtype=np.int64), + np.array([1, 1], dtype=np.int64) + ], + expected=np.array([[], []], dtype=np.float32)) self._testNAry(lambda x: array_ops.strided_slice(*x), [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], diff --git a/tensorflow/compiler/tests/plugin.bzl b/tensorflow/compiler/tests/plugin.bzl new file mode 100644 index 00000000000..b6eb7a9e395 --- /dev/null +++ b/tensorflow/compiler/tests/plugin.bzl @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional XLA devices to be included in the unit test suite.""" + +# If you wish to edit this file without checking it into the repo, consider: +# git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl + +plugins = { + #"poplar": {"device":"XLA_IPU", "types":"DT_FLOAT,DT_INT32", "tags":[]}, +} + diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py new file mode 100644 index 00000000000..eb48fe555a0 --- /dev/null +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -0,0 +1,400 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for 3d pooling operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +# Wrapper around AvgPoolGrad that ignores extra arguments needed by +# MaxPoolGrad. +def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): + del outputs # Unused by average-pooling gradients. + return gen_nn_ops._avg_pool3d_grad( + inputs.get_shape().as_list(), + output_gradients, + ksize=ksize, + strides=strides, + padding=padding) + + +class Pooling3DTest(XLATestCase): + + def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, + expected): + """Verifies the output values of the pooling function. + + Args: + pool_func: Function to be called: co.MaxPool, co.AvgPool. + input_sizes: Input tensor dimensions. + window: Tuple of kernel dims: planes, rows, cols. + strides: Tuple of strides for dims: planes, rows, cols. + padding: Padding type. + expected: An array containing the expected operation outputs. + """ + total_size = 1 + for s in input_sizes: + total_size *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. 
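+ # (For example, for input_sizes=[1, 3, 3, 3, 3], channel 0 of the single + # VALID 2x2x2 window holds 1, 4, 10, 13, 28, 31, 37, 40, whose average is + # 20.5; this is how the expected values in the tests below were derived.)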
+ x = np.arange(1.0, total_size + 1, dtype=np.float32) + x = x.reshape(input_sizes) + with self.test_session() as sess, self.test_scope(): + inputs = array_ops.placeholder(dtypes.float32) + t = pool_func( + inputs, + ksize=[1] + window + [1], + strides=[1] + strides + [1], + padding=padding) + vals = sess.run(t, {inputs: x}) + # Verifies values. + actual = vals.flatten() + self.assertAllClose(expected, actual) + + def testAvgPool3dValidPadding(self): + expected_output = [20.5, 21.5, 22.5] + self._VerifyValues( + nn_ops.avg_pool3d, + input_sizes=[1, 3, 3, 3, 3], + window=[2, 2, 2], + strides=[2, 2, 2], + padding="VALID", + expected=expected_output) + + def testAvgPool3dSamePadding(self): + expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5] + self._VerifyValues( + nn_ops.avg_pool3d, + input_sizes=[1, 2, 2, 4, 3], + window=[2, 2, 2], + strides=[2, 2, 2], + padding="SAME", + expected=expected_output) + + def testAvgPool3dSamePaddingDifferentStrides(self): + expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5] + self._VerifyValues( + nn_ops.avg_pool3d, + input_sizes=[1, 5, 8, 1, 1], + window=[1, 2, 3], + strides=[2, 3, 1], + padding="SAME", + expected=expected_output) + + def testMaxPool3dValidPadding(self): + expected_output = [40.0, 41.0, 42.0] + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 3, 3, 3, 3], + window=[2, 2, 2], + strides=[2, 2, 2], + padding="VALID", + expected=expected_output) + + def testMaxPool3dSamePadding(self): + expected_output = [31., 32., 33., 34., 35., 36.] + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 2, 2, 3, 3], + window=[2, 2, 2], + strides=[2, 2, 2], + padding="SAME", + expected=expected_output) + + def testMaxPool3dSamePaddingDifferentStrides(self): + expected_output = [2., 5., 8., 18., 21., 24., 34., 37., 40.] + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 5, 8, 1, 1], + window=[1, 2, 3], + strides=[2, 3, 1], + padding="SAME", + expected=expected_output) + + # Test pooling on a larger input, with different stride and kernel + # size for the 'z' dimension. + + # Simulate max pooling in numpy to get the expected output. 
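+ # Because the values increase with index, the max of every 2x2 window is + # its highest-index element, so stride-2 SAME pooling reduces to the + # strided slice taken below; only the windows that extend into the zero + # padding need fixing up afterwards.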
+ input_data = np.arange(1, 5 * 27 * 27 * 64 + 1).reshape((5, 27, 27, 64)) + input_data = np.pad(input_data, [[0, 0], [0, 1], [0, 1], [0, 0]], + mode="constant") + expected_output = input_data[:, 1::2, 1::2, :] + expected_output[:, -1, :, :] = input_data[:, -2, 1::2, :] + expected_output[:, :, -1, :] = input_data[:, 1::2, -2, :] + expected_output[:, -1, -1, :] = input_data[:, -2, -2, :] + + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 5, 27, 27, 64], + window=[1, 2, 2], + strides=[1, 2, 2], + padding="SAME", + expected=expected_output.flatten()) + + def testKernelSmallerThanStride(self): + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 3, 3, 3, 1], + window=[1, 1, 1], + strides=[2, 2, 2], + padding="SAME", + expected=[1, 3, 7, 9, 19, 21, 25, 27]) + + self._VerifyValues( + nn_ops.max_pool3d, + input_sizes=[1, 7, 7, 7, 1], + window=[2, 2, 2], + strides=[3, 3, 3], + padding="VALID", + expected=[58, 61, 79, 82, 205, 208, 226, 229]) + + self._VerifyValues( + nn_ops.avg_pool3d, + input_sizes=[1, 3, 3, 3, 1], + window=[1, 1, 1], + strides=[2, 2, 2], + padding="SAME", + expected=[1, 3, 7, 9, 19, 21, 25, 27]) + + self._VerifyValues( + nn_ops.avg_pool3d, + input_sizes=[1, 7, 7, 7, 1], + window=[2, 2, 2], + strides=[3, 3, 3], + padding="VALID", + expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5]) + + def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize, + strides, padding): + """Verifies the output values of the pooling gradient function. + + Args: + pool_func: Forward pooling function. + pool_grad_func: Pooling gradient function corresponding to pool_func. + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions. + strides: The stride dimensions. + padding: Padding type. + """ + ksize = [1] + ksize + [1] + strides = [1] + strides + [1] + total_size = np.prod(input_sizes) + x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + with self.test_session() as sess: + # Use the forward pool function to compute some corresponding outputs + # (needed for the CPU device, and we need the shape in both cases). + with ops.device("CPU"): + inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes) + outputs = pool_func( + inputs, + ksize=ksize, + strides=strides, + padding=padding) + + output_vals = np.array(sess.run(outputs, {inputs: x})) + output_gradient_vals = np.arange( + 1, output_vals.size + 1, dtype=np.float32) + output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + + # Use the TensorFlow CPU pooling gradient to compute the expected input + # gradients. + with ops.device("CPU"): + output_gradients = array_ops.placeholder( + dtypes.float32, shape=output_vals.shape) + expected_input_gradients = pool_grad_func( + inputs, + outputs, + output_gradients, + ksize=ksize, + strides=strides, + padding=padding) + expected_input_gradient_vals = sess.run( + expected_input_gradients, + {inputs: x, + output_gradients: output_gradient_vals}) + + # Run the gradient op on the XLA device + with self.test_scope(): + outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) + actual_input_gradients = pool_grad_func( + inputs, + outputs, + output_gradients, + ksize=ksize, + strides=strides, + padding=padding) + actual = sess.run(actual_input_gradients, { + inputs: x, + outputs: output_vals, + output_gradients: output_gradient_vals + }) + + # Compare the TensorFlow and XLA results.
+ self.assertAllClose( + expected_input_gradient_vals.flatten(), + actual.flatten(), + rtol=1e-5, + atol=1e-6) + self.assertShapeEqual(actual, inputs) + + def testMaxPoolGradValidPadding1_1_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[1, 3, 3, 3, 1], + ksize=[1, 1, 1], + strides=[1, 1, 1], + padding="VALID") + + def testMaxPoolGradValidPadding2_1_6_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 3, 3, 6, 3], + ksize=[2, 2, 2], + strides=[1, 1, 1], + padding="VALID") + + def testMaxPoolGradValidPadding2_1_7_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 3, 5, 7, 3], + ksize=[2, 2, 2], + strides=[1, 1, 1], + padding="VALID") + + def testMaxPoolGradValidPadding2_2_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 2, 2, 2, 3], + ksize=[2, 2, 2], + strides=[2, 2, 2], + padding="VALID") + + def testMaxPoolGradSamePadding1_1_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 3, 2, 4, 1], + ksize=[1, 1, 1], + strides=[1, 1, 1], + padding="SAME") + + def testMaxPoolGradSamePadding2_1_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 3, 2, 4, 1], + ksize=[2, 2, 2], + strides=[1, 1, 1], + padding="SAME") + + def testMaxPoolGradSamePadding2_2_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[2, 5, 2, 4, 3], + ksize=[2, 2, 2], + strides=[2, 2, 2], + padding="SAME") + + def testMaxPoolGradSamePadding3_1_3d(self): + self._VerifyGradient( + nn_ops.max_pool3d, + gen_nn_ops._max_pool3d_grad, + input_sizes=[1, 3, 3, 7, 1], + ksize=[3, 3, 3], + strides=[1, 1, 1], + padding="SAME") + + def testAvgPoolGradValidPadding1_1_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[2, 3, 3, 3, 3], + ksize=[1, 1, 1], + strides=[1, 1, 1], + padding="VALID") + + def testAvgPoolGradValidPadding2_1_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[2, 3, 3, 3, 3], + ksize=[2, 2, 2], + strides=[1, 1, 1], + padding="VALID") + + def testAvgPoolGradValidPadding2_2_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[2, 2, 2, 2, 3], + ksize=[2, 2, 2], + strides=[2, 2, 2], + padding="VALID") + + def testAvgPoolGradSamePadding1_1_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[2, 3, 2, 4, 3], + ksize=[1, 1, 1], + strides=[1, 1, 1], + padding="SAME") + + def testAvgPoolGradSamePadding2_1_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[1, 2, 2, 2, 1], + ksize=[2, 2, 2], + strides=[1, 1, 1], + padding="SAME") + + def testAvgPoolGradSamePadding2_2_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[2, 5, 2, 4, 3], + ksize=[2, 2, 2], + strides=[2, 2, 2], + padding="SAME") + + def testAvgPoolGradSamePadding3_1_3d(self): + self._VerifyGradient( + nn_ops.avg_pool3d, + _AvgPoolGrad, + input_sizes=[1, 3, 6, 7, 1], + ksize=[3, 3, 3], + strides=[1, 1, 1], + padding="SAME") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index b7c8b3f5980..a17a3f3d653 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ 
-72,6 +72,17 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsInRange(self): + count = 10000 + # TODO(b/34339814): implement inverse erf support for non-F32 types. + for dtype in [dtypes.float32]: + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) + y = sess.run(x) + self.assertTrue((y >= -2).sum() == count) + self.assertTrue((y <= 2).sum() == count) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index abc0cb2cce7..d3821ad02e5 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -68,6 +68,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -75,6 +76,7 @@ namespace { // Command line flags: see main() below. int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; +int64 tf_xla_max_tensor_size = 100000LL; string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; @@ -93,7 +95,12 @@ class OpTestBuilder { explicit OpTestBuilder(const string& op_name); // Adds an input 'tensor'. - OpTestBuilder& Input(Tensor tensor); + OpTestBuilder& Input(const Tensor& tensor); + + // Adds a random input tensor with 'type'. If 'dims' is not provided, + // RandomDims() is used. + OpTestBuilder& RandomInput(DataType type); + OpTestBuilder& RandomInput(DataType type, std::vector dims); // Sets an attribute. template @@ -110,25 +117,54 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. 
- Status BuildGraph(string name_prefix, string device, bool use_jit, - GraphDef* graphdef, NodeDef** test_node_def, + Status BuildGraph(const string& name_prefix, const string& device, + bool use_jit, GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const; - const std::vector& inputs() const { return inputs_; } + struct InputDescription { + Tensor tensor; + + DataType type = DT_INVALID; + bool has_dims = false; + std::vector dims; + }; + + const std::vector& inputs() const { return inputs_; } private: NodeDef node_def_; - std::vector inputs_; + std::vector inputs_; }; OpTestBuilder::OpTestBuilder(const string& op_name) { node_def_.set_op(op_name); } -OpTestBuilder& OpTestBuilder::Input(Tensor tensor) { +OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) { VLOG(1) << "Adding input: " << tensor.DebugString(); - inputs_.push_back(tensor); + InputDescription input; + input.tensor = tensor; + inputs_.push_back(input); + return *this; +} + +OpTestBuilder& OpTestBuilder::RandomInput(DataType type) { + VLOG(1) << "Adding random input: " << type; + InputDescription input; + input.type = type; + inputs_.push_back(input); + return *this; +} + +OpTestBuilder& OpTestBuilder::RandomInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.dims = std::move(dims); + inputs_.push_back(input); return *this; } @@ -145,9 +181,9 @@ OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, return *this; } -Status OpTestBuilder::BuildGraph(string name_prefix, string device, - bool use_jit, GraphDef* graphdef, - NodeDef** test_node_def, +Status OpTestBuilder::BuildGraph(const string& name_prefix, + const string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); @@ -206,23 +242,36 @@ class OpTest : public ::testing::Test { public: OpTest(); - // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs; - // whichever happens first. - void Repeatedly(std::function fn); + enum TestResult { + // The test saw an unrecoverable error. Don't try any more runs. + kFatalError, + // The parameters of the test were invalid (e.g., the "golden" + // implementation failed, or the parameters are oversize). Reruns are ok. + kInvalid, + // The test ran successfully, and we have a verdict. Does *not* mean the + // test passed. + kOk, + }; + + // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure + // occurs; whichever happens first. Reruns if the TestResult is kInvalid. + void Repeatedly(const std::function& fn); // Select a random element from 'candidates'. template T Choose(gtl::ArraySlice candidates); static constexpr int kDefaultMaxRank = 5; - static constexpr int64 kDefaultMaxDimensionSize = 20LL; + static constexpr int64 kDefaultMaxDimensionSize = 256LL; - // Returns a random dimension size. + // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. + bool TensorSizeIsOk(gtl::ArraySlice dims); + + // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); // Returns a random shape. The tensor has rank in the range [min_rank, - // max_rank). - // Each dimension has size [0, kDefaultMaxDimensionSize]. + // max_rank). Each dimension has size [min_size, max_size). 
std::vector RandomDims(int min_rank = 0, int max_rank = kDefaultMaxRank, int64 min_size = 0, @@ -252,17 +301,23 @@ class OpTest : public ::testing::Test { // for use as reduction indices. Tensor RandomReductionIndices(int rank); - struct WindowedDims { + struct WindowedSpatialDims { Padding padding; - int kernel_rows, kernel_cols; - int stride_rows, stride_cols; - int input_rows, input_cols; - int64 output_rows, output_cols; + std::vector kernel_dims; + std::vector stride_dims; + std::vector input_dims; + std::vector output_dims; }; - // Choose dimensions for a 2D windowed op such as pooling or convolution. - // TODO(phawkins): currently this only produces spatial windows, in NHWC - // format. - WindowedDims ChooseWindowedDims(); + // Choose spatial dimensions for a windowed op such as pooling or convolution. + WindowedSpatialDims ChooseWindowedSpatialDims(int num_spatial_dims); + + // Builds dimensions for a windowed op such as pooling or convolution, + // including a batch and feature dimension. + std::vector ImageDims(TensorFormat format, int batch, int feature, + const std::vector& spatial_dims); + + // Converts an int64 vector to an int32 vector. + std::vector AsInt32s(const std::vector& int64s); std::mt19937& generator() { return *generator_; } @@ -272,8 +327,9 @@ class OpTest : public ::testing::Test { // element-wise difference between x and y must no more than // atol + rtol * abs(x); or both elements may be NaN or infinity. For // non-floating-point tensors the element values must match exactly. - void ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, - double atol = 1e-2, double rtol = 1e-2); + TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, + double atol = 1e-2, + double rtol = 1e-2); protected: // Per-test state: @@ -309,10 +365,35 @@ OpTest::OpTest() { TF_CHECK_OK(session_->Create(def)); } -void OpTest::Repeatedly(std::function fn) { +void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; - for (int i = 0; !HasFailure() && i < max_repetitions; ++i) { - fn(); + int valid_test_runs = 0; + // We run up to 10 * max_repetitions times; the idea is that if we roll the + // dice enough times we will find some valid parameters. We want to put an + // upper limit on the number iterations just in case the probability of + // finding feasible parameters is very low. 
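+ // With the default --tf_xla_test_repetitions=20 this means at most 200 + // attempts to collect 20 valid runs.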
+ for (int i = 0; !HasFailure() && i < max_repetitions * 10 && + valid_test_runs < max_repetitions; + ++i) { + TestResult result = fn(); + switch (result) { + case kOk: + ++valid_test_runs; + break; + + case kFatalError: + ASSERT_TRUE(false) << "Test had fatal failure"; + return; + + case kInvalid: + break; + } + } + if (!HasFailure()) { + EXPECT_GE(valid_test_runs, max_repetitions) + << "Not enough test instances passed; this means that either the " + "golden implementation is buggy or the operator harness is not " + "producing well-formed test cases with a high probability."; } } @@ -327,6 +408,14 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } +bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { + int64 size = 1LL; + for (int64 dim : dims) { + size *= dim; + } + return size < tf_xla_max_tensor_size; +} + std::vector OpTest::RandomDims(int min_rank, int max_rank, int64 min_size, int64 max_size) { CHECK_LE(0, min_rank); @@ -334,9 +423,13 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, std::uniform_int_distribution rank_distribution(min_rank, max_rank); int rank = rank_distribution(generator()); std::vector dims(rank); - std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() { - return RandomDim(min_size, max_size); - }); + // TODO(phawkins): too small a maximum tensor size could lead to an infinite + // loop here. + do { + std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() { + return RandomDim(min_size, max_size); + }); + } while (!TensorSizeIsOk(dims)); return dims; } @@ -473,35 +566,63 @@ Tensor OpTest::RandomReductionIndices(int rank) { return test::AsTensor(indices); } -OpTest::WindowedDims OpTest::ChooseWindowedDims() { - WindowedDims d; +OpTest::WindowedSpatialDims OpTest::ChooseWindowedSpatialDims( + int num_spatial_dims) { + WindowedSpatialDims d; d.padding = Choose({SAME, VALID}); std::uniform_int_distribution random_int(1, 5); - Status s; - // Repeatedly try different filter/stride sizes until we find a valid - // combination. - do { - // CPU implementations require stride <= kernel size. - d.kernel_rows = random_int(generator()), - d.input_rows = RandomDim(d.kernel_rows); - d.stride_rows = - std::uniform_int_distribution(1, d.kernel_rows)(generator()); - int64 pad_dummy; - s = GetWindowedOutputSize(d.input_rows, d.kernel_rows, d.stride_rows, - d.padding, &d.output_rows, &pad_dummy); - } while (!s.ok()); - do { - d.kernel_cols = random_int(generator()); - d.input_cols = RandomDim(d.kernel_cols); - d.stride_cols = - std::uniform_int_distribution(1, d.kernel_cols)(generator()); - int64 pad_dummy; - s = GetWindowedOutputSize(d.input_cols, d.kernel_cols, d.stride_cols, - d.padding, &d.output_cols, &pad_dummy); - } while (!s.ok()); + d.kernel_dims.resize(num_spatial_dims); + d.input_dims.resize(num_spatial_dims); + d.output_dims.resize(num_spatial_dims); + d.stride_dims.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + Status s; + // Repeatedly try different filter/stride sizes until we find a valid + // combination. + do { + // CPU implementations require stride <= kernel size. 
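+ // Drawing the input dimension with RandomDim(d.kernel_dims[i]) keeps it + // at least as large as the kernel, so a VALID window always fits.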
+ d.kernel_dims[i] = random_int(generator()), + d.input_dims[i] = RandomDim(d.kernel_dims[i]); + d.stride_dims[i] = + std::uniform_int_distribution(1, d.kernel_dims[i])(generator()); + int64 pad_dummy; + s = GetWindowedOutputSize(d.input_dims[i], d.kernel_dims[i], + d.stride_dims[i], d.padding, &d.output_dims[i], + &pad_dummy); + } while (!s.ok()); + } return d; } +std::vector OpTest::ImageDims(TensorFormat format, int batch, + int feature, + const std::vector& spatial_dims) { + std::vector dims; + switch (format) { + case FORMAT_NHWC: + dims.push_back(batch); + for (int dim : spatial_dims) { + dims.push_back(dim); + } + dims.push_back(feature); + break; + case FORMAT_NCHW: + dims.push_back(batch); + dims.push_back(feature); + for (int dim : spatial_dims) { + dims.push_back(dim); + } + break; + case FORMAT_NCHW_VECT_C: + LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + } + return dims; +} + +std::vector OpTest::AsInt32s(const std::vector& int64s) { + return std::vector(int64s.begin(), int64s.end()); +} + // Functions for comparing tensors. template @@ -574,53 +695,84 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, } } -void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, - double atol, double rtol) { +OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( + const OpTestBuilder& builder, double atol, double rtol) { + const std::vector& inputs = builder.inputs(); + std::vector input_tensors; + input_tensors.reserve(inputs.size()); + for (const OpTestBuilder::InputDescription& input : inputs) { + if (input.type == DT_INVALID) { + VLOG(1) << "Input: " << input.tensor.DebugString(); + input_tensors.push_back(input.tensor); + } else { + VLOG(1) << "Input: " << input.type << " " + << TensorShape(input.dims).DebugString(); + std::vector dims; + if (input.has_dims) { + dims = input.dims; + } else { + dims = RandomDims(); + } + if (!TensorSizeIsOk(dims)) { + VLOG(1) << "Ignoring oversize dims."; + return kInvalid; + } + input_tensors.push_back(RandomTensor(input.type, dims)); + } + } + string cpu_device = LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; - ASSERT_TRUE( - DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)); + if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) { + LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr; + return kFatalError; + } DeviceType test_device_type(parsed_name.type); ++num_tests_; GraphDef graph; std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; - TF_ASSERT_OK(builder.BuildGraph( + Status status = builder.BuildGraph( strings::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, - &expected_inputs, &expected_fetches)); + &expected_inputs, &expected_fetches); + if (!status.ok()) { + LOG(ERROR) << "Expected graph construction failed: " << status; + return kFatalError; + } NodeDef* node_def; - TF_ASSERT_OK(builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), - test_device, tf_xla_test_use_jit, &graph, - &node_def, &test_inputs, &test_fetches)); + status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + test_device, tf_xla_test_use_jit, &graph, + &node_def, &test_inputs, &test_fetches); + if (!status.ok()) { + LOG(ERROR) << "Test graph construction failed: " << status; + return kFatalError; + } // 
Check that there's a kernel corresponding to 'node_def' on the device under // test. - Status status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr); + status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr); if (!status.ok()) { VLOG(1) << "Skipping test because there is no corresponding registered " << "kernel on the test device: " << status; - return; + return kInvalid; } - TF_ASSERT_OK(session_->Extend(graph)); - - const std::vector& input_tensors = builder.inputs(); - if (VLOG_IS_ON(1)) { - for (const Tensor& input : input_tensors) { - VLOG(1) << "Input: " << input.DebugString(); - } + status = session_->Extend(graph); + if (!status.ok()) { + LOG(ERROR) << "Session::Extend() failed: " << status; + return kFatalError; } std::vector> expected_feeds(expected_inputs.size()); std::vector> test_feeds(test_inputs.size()); - ASSERT_EQ(input_tensors.size(), expected_inputs.size()); - ASSERT_EQ(input_tensors.size(), test_inputs.size()); + CHECK_EQ(input_tensors.size(), expected_inputs.size()); + CHECK_EQ(input_tensors.size(), test_inputs.size()); for (int i = 0; i < input_tensors.size(); ++i) { expected_feeds[i] = {expected_inputs[i], input_tensors[i]}; @@ -632,18 +784,27 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, Status s = session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs); if (!s.ok()) { - VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test"; - return; + VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test"; + return kInvalid; + } + for (const Tensor& expected : expected_outputs) { + VLOG(1) << "Expected: " << expected.DebugString(); } VLOG(1) << "Running test graph"; - TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs)); + status = session_->Run(test_feeds, test_fetches, {}, &test_outputs); + if (!status.ok()) { + LOG(ERROR) << "Test graph failed: " << status; + return kFatalError; + } - ASSERT_EQ(expected_outputs.size(), test_outputs.size()); + CHECK_EQ(expected_outputs.size(), test_outputs.size()); for (int j = 0; s.ok() && j < test_outputs.size(); ++j) { s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol); } TF_EXPECT_OK(s); + + return kOk; } // Helper that converts 'values' to an int32 or int64 Tensor. 
@@ -663,8 +824,8 @@ Tensor AsIntTensor(DataType dtype, const std::vector& values) { TEST_F(OpTest, Abs) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Abs").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Abs").RandomInput(type).Attr("T", type)); }); } @@ -672,10 +833,10 @@ TEST_F(OpTest, Add) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -690,47 +851,50 @@ TEST_F(OpTest, AddN) { builder.Attr("T", type); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type, shape)); + builder.RandomInput(type, shape); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, All) { Repeatedly([this]() { - Tensor data = RandomTensor(DT_BOOL); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("All").Input(data).Input(indices).Attr("keep_dims", - keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("All") + .RandomInput(DT_BOOL, data_dims) + .Input(indices) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, Any) { Repeatedly([this]() { - Tensor data = RandomTensor(DT_BOOL); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Any").Input(data).Input(indices).Attr("keep_dims", - keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Any") + .RandomInput(DT_BOOL, data_dims) + .Input(indices) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, AvgPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); - int kernel_rows = random_int(generator()), - kernel_cols = random_int(generator()); + std::vector dims = RandomDims(4, 4, 1); + int kernel_rows = + std::uniform_int_distribution(1, dims[1])(generator()); + int kernel_cols = + std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool") - .Input( - RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), - RandomDim(kernel_cols), RandomDim(1)})) + .RandomInput(DT_FLOAT, dims) .Attr("T", DT_FLOAT) .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) .Attr("strides", {1, stride_rows, stride_cols, 1}) @@ -741,21 +905,72 @@ TEST_F(OpTest, AvgPool) { // for batch pooling when supported. 
} +TEST_F(OpTest, AvgPool3D) { + Repeatedly([this]() { + std::uniform_int_distribution random_int(1, 5); + std::vector dims = RandomDims(5, 5, 1); + + std::vector input_dims, kernel_dims, stride_dims; + for (int i = 0; i < 3; ++i) { + kernel_dims.push_back( + std::uniform_int_distribution(1, dims[i])(generator())); + input_dims.push_back(dims[i]); + stride_dims.push_back(random_int(generator())); + } + int64 batch = dims[3]; + int64 feature = dims[4]; + + string padding = Choose({"SAME", "VALID"}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("AvgPool3D") + .RandomInput(DT_FLOAT, + ImageDims(FORMAT_NHWC, batch, feature, input_dims)) + .Attr("T", DT_FLOAT) + .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, kernel_dims)) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, stride_dims)) + .Attr("padding", padding) + .Attr("data_format", "NDHWC")); + }); + // TODO(phawkins): test NCHW format (not supported by CPU) +} + TEST_F(OpTest, AvgPoolGrad) { Repeatedly([this]() { int batch = RandomDim(1), features = RandomDim(1); - WindowedDims d = ChooseWindowedDims(); - ExpectTfAndXlaOutputsAreClose( + WindowedSpatialDims d = ChooseWindowedSpatialDims(2); + std::vector input_dims = + AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); + std::vector output_dims = + ImageDims(FORMAT_NHWC, batch, features, d.output_dims); + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPoolGrad") - .Input(test::AsTensor( - {batch, d.input_rows, d.input_cols, features})) - .Input(RandomTensor( - DT_FLOAT, {batch, d.output_rows, d.output_cols, features})) + .Input(test::AsTensor(input_dims)) + .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) - .Attr("ksize", {1, d.kernel_rows, d.kernel_cols, 1}) - .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") - .Attr("data_format", "NHWC")); + .Attr("data_format", "NDHWC")); + }); +} + +TEST_F(OpTest, AvgPool3DGrad) { + Repeatedly([this]() { + int batch = RandomDim(1), features = RandomDim(1); + WindowedSpatialDims d = ChooseWindowedSpatialDims(3); + std::vector input_dims = + AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); + std::vector output_dims = + ImageDims(FORMAT_NHWC, batch, features, d.output_dims); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("AvgPool3DGrad") + .Input(test::AsTensor(input_dims)) + .RandomInput(DT_FLOAT, output_dims) + .Attr("T", DT_FLOAT) + .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("padding", d.padding == SAME ? 
"SAME" : "VALID") + .Attr("data_format", "NDHWC")); }); } @@ -767,60 +982,127 @@ TEST_F(OpTest, BatchMatMul) { std::vector x_dims(output_dims), y_dims(output_dims); x_dims[ndims - 1] = inner_dim; y_dims[ndims - 2] = inner_dim; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT)); - std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_x", true)); + std::bernoulli_distribution random_bool; + bool adj_x = random_bool(generator()); + bool adj_y = random_bool(generator()); + if (adj_x) { + std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); + } + if (adj_y) { + std::swap(y_dims[ndims - 1], y_dims[ndims - 2]); + } - std::swap(y_dims[ndims - 1], y_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_x", true) - .Attr("adj_y", true)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT) + .Attr("adj_x", adj_x) + .Attr("adj_y", adj_y)); + }); +} - std::swap(x_dims[ndims - 1], x_dims[ndims - 2]); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .Input(RandomTensor(DT_FLOAT, x_dims)) - .Input(RandomTensor(DT_FLOAT, y_dims)) - .Attr("T", DT_FLOAT) - .Attr("adj_y", true)); +TEST_F(OpTest, BatchToSpace) { + Repeatedly([this]() { + const int num_block_dims = 2; + std::vector block_dims = + RandomDims(num_block_dims, num_block_dims, 0, 5); + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_size; + input_dims[1 + i] = block_dims[i]; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 4); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. + crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") + .RandomInput(DT_FLOAT, input_dims) + .Input(crops) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, BatchToSpaceND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_dims[i]; + } + std::copy(block_multipliers.begin(), block_multipliers.end(), + input_dims.begin() + 1); + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 3); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. 
+ crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BatchToSpaceND") + .RandomInput(DT_FLOAT, input_dims) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(crops) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, BiasAdd) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); - auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); + auto x_dims = RandomDims(2, kDefaultMaxRank); + auto y_dims = {x_dims[x_dims.size() - 1]}; // TODO(phawkins): test both data formats. - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAdd").Input(x).Input(y).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, BiasAddGrad) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT); // TODO(phawkins): test both data formats. - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddGrad").Input(x).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, BiasAddV1) { Repeatedly([this]() { - auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); - auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddV1").Input(x).Input(y).Attr("T", DT_FLOAT)); + auto x_dims = RandomDims(2, kDefaultMaxRank); + auto y_dims = {x_dims[x_dims.size() - 1]}; + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1") + .RandomInput(DT_FLOAT, x_dims) + .RandomInput(DT_FLOAT, y_dims) + .Attr("T", DT_FLOAT)); }); } @@ -830,10 +1112,11 @@ TEST_F(OpTest, BroadcastGradientArgs) { // DataType type = Choose({DT_INT32, DT_INT64}); DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BroadcastGradientArgs") - .Input(AsIntTensor(type, dims.first)) - .Input(AsIntTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BroadcastGradientArgs") + .Input(AsIntTensor(type, dims.first)) + .Input(AsIntTensor(type, dims.second)) + .Attr("T", type)); }); } @@ -842,18 +1125,17 @@ TEST_F(OpTest, Cast) { DataType src_type, dst_type; src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") - .Input(RandomTensor(src_type)) - .Attr("SrcT", src_type) - .Attr("DstT", dst_type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .RandomInput(src_type) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type)); }); } TEST_F(OpTest, Ceil) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Ceil") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Ceil").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } @@ -873,9 +1155,9 @@ TEST_F(OpTest, Concat) { for (int i = 0; i < n; ++i) { std::vector shape = dims; shape[concat_dim] = RandomDim(); - builder.Input(RandomTensor(type, shape)); + builder.RandomInput(type, shape); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } @@ -895,27 +1177,30 @@ TEST_F(OpTest, ConcatOffset) { 
shape[concat_dim] = RandomDim(); builder.Input(test::AsTensor(shape)); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, Conv2D) { Repeatedly([this]() { - WindowedDims d = ChooseWindowedDims(); + WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); int features_in = random_int(generator()); int features_out = random_int(generator()); - Tensor data = RandomTensor( - DT_FLOAT, {RandomDim(), d.input_rows, d.input_cols, features_in}); - Tensor kernel = RandomTensor( - DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + int64 batch = RandomDim(); + + std::vector data_dims = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); + + std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], + features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") - .Input(data) - .Input(kernel) + .RandomInput(DT_FLOAT, data_dims) + .RandomInput(DT_FLOAT, kernel_dims) .Attr("T", DT_FLOAT) - .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); @@ -923,24 +1208,24 @@ TEST_F(OpTest, Conv2D) { TEST_F(OpTest, Conv2DBackpropFilter) { Repeatedly([this]() { - WindowedDims d = ChooseWindowedDims(); + WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); int features_in = random_int(generator()); int features_out = random_int(generator()); int32 batch = RandomDim(); - Tensor activations = RandomTensor( - DT_FLOAT, {batch, d.input_rows, d.input_cols, features_in}); - Tensor backprop = RandomTensor( - DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out}); - Tensor kernel_shape = test::AsTensor( - {d.kernel_rows, d.kernel_cols, features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + std::vector activations = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + Tensor kernel_shape = test::AsTensor(AsInt32s( + {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") - .Input(activations) + .RandomInput(DT_FLOAT, activations) .Input(kernel_shape) - .Input(backprop) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? 
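// Illustrative sketch (standard TF windowing rules restated; the helper is
// hypothetical): ChooseWindowedSpatialDims presumably keeps input, kernel,
// stride and output dims consistent with these formulas (dilation 1):
//   SAME:  out = ceil(in / stride)
//   VALID: out = ceil((in - kernel + 1) / stride), requiring in >= kernel.
long long WindowedOutputSize(long long in, long long kernel, long long stride,
                             bool same_padding) {
  const long long span = same_padding ? in : in - kernel + 1;
  return (span + stride - 1) / stride;  // integer ceiling division
}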
"SAME" : "VALID") .Attr("data_format", "NHWC")); }); @@ -948,35 +1233,111 @@ TEST_F(OpTest, Conv2DBackpropFilter) { TEST_F(OpTest, Conv2DBackpropInput) { Repeatedly([this]() { - WindowedDims d = ChooseWindowedDims(); + WindowedSpatialDims d = ChooseWindowedSpatialDims(2); std::uniform_int_distribution random_int(1, 5); int features_in = random_int(generator()); int features_out = random_int(generator()); int32 batch = RandomDim(); - Tensor in_shape = - test::AsTensor({batch, d.input_rows, d.input_cols, features_in}); - Tensor backprop = RandomTensor( - DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out}); - Tensor kernel = RandomTensor( - DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out}); - ExpectTfAndXlaOutputsAreClose( + Tensor in_shape = test::AsTensor( + AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) - .Input(kernel) - .Input(backprop) + .RandomInput(DT_FLOAT, kernel) + .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", {1, d.stride_rows, d.stride_cols, 1}) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); } +TEST_F(OpTest, Conv3D) { + Repeatedly([this]() { + WindowedSpatialDims d = ChooseWindowedSpatialDims(3); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + std::vector data = {RandomDim(), d.input_dims[0], d.input_dims[1], + d.input_dims[2], features_in}; + + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + d.kernel_dims[2], features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv3D") + .RandomInput(DT_FLOAT, data) + .RandomInput(DT_FLOAT, kernel) + .Attr("T", DT_FLOAT) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); + }); +} + +TEST_F(OpTest, Conv3DBackpropFilter) { + Repeatedly([this]() { + WindowedSpatialDims d = ChooseWindowedSpatialDims(3); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + int32 batch = RandomDim(1); + std::vector activations = + ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + Tensor kernel_shape = test::AsTensor( + AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], + features_in, features_out})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv3DBackpropFilterV2") + .RandomInput(DT_FLOAT, activations) + .Input(kernel_shape) + .RandomInput(DT_FLOAT, backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("padding", d.padding == SAME ? 
"SAME" : "VALID")); + }); +} + +TEST_F(OpTest, Conv3DBackpropInput) { + Repeatedly([this]() { + WindowedSpatialDims d = ChooseWindowedSpatialDims(3); + std::uniform_int_distribution random_int(1, 5); + int features_in = random_int(generator()); + int features_out = random_int(generator()); + int32 batch = RandomDim(1); + Tensor in_shape = test::AsTensor( + AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); + std::vector backprop = + ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); + std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], + d.kernel_dims[2], features_in, features_out}; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Conv3DBackpropInputV2") + .Input(in_shape) + .RandomInput(DT_FLOAT, kernel) + .RandomInput(DT_FLOAT, backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); + }); +} + TEST_F(OpTest, Diag) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Diag") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + std::vector dims; + // Diag causes a quadratic blowup in output size. + int64 size; + do { + dims = RandomDims(1); + size = TensorShape(dims).num_elements(); + } while (size * size < tf_xla_max_tensor_size); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); } @@ -988,9 +1349,9 @@ TEST_F(OpTest, DiagPart) { std::vector doubled_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), doubled_dims.begin()); std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart") - .Input(RandomTensor(type, doubled_dims)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart") + .RandomInput(type, doubled_dims) + .Attr("T", type)); }); } @@ -998,10 +1359,10 @@ TEST_F(OpTest, Div) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1050,10 +1411,26 @@ TEST_F(OpTest, DynamicStitch) { std::vector dims(index_dims[i].begin(), index_dims[i].end()); std::copy(constant_dims.begin(), constant_dims.end(), std::back_inserter(dims)); - Tensor t = RandomTensor(type, dims); - builder.Input(t); + builder.RandomInput(type, dims); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, Elu) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Elu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, EluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } @@ -1061,50 +1438,51 @@ TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return 
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Exp) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Exp").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ExpandDims) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor in = RandomTensor(type); + std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); - std::uniform_int_distribution d(-1 - in.dims(), in.dims()); + std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); dim.scalar()() = d(generator()); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("ExpandDims").Input(in).Input(dim).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims") + .RandomInput(type, in_dims) + .Input(dim) + .Attr("T", type)); }); } TEST_F(OpTest, Fill) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor scalar = RandomTensor(type, {}); std::vector dims = RandomDims(); std::vector shape(dims.begin(), dims.end()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Fill") - .Input(test::AsTensor(shape)) - .Input(scalar) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Fill") + .Input(test::AsTensor(shape)) + .RandomInput(type, {}) + .Attr("T", type)); }); } TEST_F(OpTest, Floor) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Floor") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Floor").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } @@ -1112,10 +1490,10 @@ TEST_F(OpTest, FloorDiv) { Repeatedly([this]() { DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1123,10 +1501,10 @@ TEST_F(OpTest, FloorMod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1134,10 +1512,10 @@ TEST_F(OpTest, Greater) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1145,18 +1523,10 @@ TEST_F(OpTest, GreaterEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); - }); -} - -TEST_F(OpTest, Reciprocal) { - Repeatedly([this]() { - 
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reciprocal") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1164,9 +1534,9 @@ TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); // TODO(b/31644876): scalars currently crash. - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") + .RandomInput(type, RandomDims(1)) + .Attr("T", type)); }); } @@ -1174,10 +1544,10 @@ TEST_F(OpTest, Less) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1185,10 +1555,10 @@ TEST_F(OpTest, LessEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -1200,10 +1570,10 @@ TEST_F(OpTest, LinSpace) { }; std::uniform_int_distribution distribution(-50, 50); DataType type = Choose({DT_INT32, DT_INT64}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LinSpace") - .Input(RandomTensor(DT_FLOAT, {})) - .Input(RandomTensor(DT_FLOAT, {})) + .RandomInput(DT_FLOAT, {}) + .RandomInput(DT_FLOAT, {}) .Input(ToScalar(type, distribution(generator()))) .Attr("T", DT_FLOAT) .Attr("Tidx", type)); @@ -1212,62 +1582,62 @@ TEST_F(OpTest, LinSpace) { TEST_F(OpTest, Log) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, LogicalAnd) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogicalAnd") - .Input(RandomTensor(DT_BOOL, dims.first)) - .Input(RandomTensor(DT_BOOL, dims.second))); + .RandomInput(DT_BOOL, dims.first) + .RandomInput(DT_BOOL, dims.second)); }); } TEST_F(OpTest, LogicalNot) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("LogicalNot").Input(RandomTensor(DT_BOOL))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LogicalNot").RandomInput(DT_BOOL)); }); } TEST_F(OpTest, LogicalOr) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogicalOr") - .Input(RandomTensor(DT_BOOL, dims.first)) - .Input(RandomTensor(DT_BOOL, dims.second))); + .RandomInput(DT_BOOL, dims.first) + .RandomInput(DT_BOOL, dims.second)); }); } TEST_F(OpTest, LogSoftmax) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LogSoftmax") - 
.Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .RandomInput(DT_FLOAT, RandomDims(2, 2)) .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, LRN) { Repeatedly([this]() { - Tensor data; // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. - data = RandomTensor(DT_FLOAT, RandomDims(4, 4, 1, 8)); + std::vector data_dims = RandomDims(4, 4, 1, 8); // CuDNN requires depth_radius > 0. - std::uniform_int_distribution radius(1, data.dim_size(3)); + std::uniform_int_distribution radius(1, data_dims[3]); std::uniform_real_distribution coeff(0.01, 2.0); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRN") - .Input(data) - .Attr("T", DT_FLOAT) - .Attr("depth_radius", radius(generator())) - .Attr("bias", coeff(generator())) - .Attr("alpha", coeff(generator())) - .Attr("beta", coeff(generator()))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LRN") + .RandomInput(DT_FLOAT, data_dims) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); }); } @@ -1275,21 +1645,19 @@ TEST_F(OpTest, LRNGrad) { Repeatedly([this]() { // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed. std::vector dims = RandomDims(4, 4, 1, 8); - Tensor input_grads = RandomTensor(DT_FLOAT, dims); - Tensor input_image = RandomTensor(DT_FLOAT, dims); - Tensor output_image = RandomTensor(DT_FLOAT, dims); // CuDNN requires depth_radius > 0. - std::uniform_int_distribution radius(1, input_grads.dim_size(3)); + std::uniform_int_distribution radius(1, dims[3]); std::uniform_real_distribution coeff(0.0, 2.0); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRNGrad") - .Input(input_grads) - .Input(input_image) - .Input(output_image) - .Attr("T", DT_FLOAT) - .Attr("depth_radius", radius(generator())) - .Attr("bias", coeff(generator())) - .Attr("alpha", coeff(generator())) - .Attr("beta", coeff(generator()))); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("LRNGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT) + .Attr("depth_radius", radius(generator())) + .Attr("bias", coeff(generator())) + .Attr("alpha", coeff(generator())) + .Attr("beta", coeff(generator()))); }); } @@ -1299,59 +1667,57 @@ TEST_F(OpTest, MatMul) { int64 y = RandomDim(); int64 z = RandomDim(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("T", DT_FLOAT)); + std::vector a_dims = {x, y}; + std::vector b_dims = {y, z}; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("T", DT_FLOAT) - .Attr("transpose_a", true)); + std::bernoulli_distribution random_bool; + bool transpose_a = random_bool(generator()); + bool transpose_b = random_bool(generator()); + if (transpose_a) { + std::swap(a_dims[0], a_dims[1]); + } + if (transpose_b) { + std::swap(b_dims[0], b_dims[1]); + } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("T", DT_FLOAT) - .Attr("transpose_b", true)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("T", DT_FLOAT) - .Attr("transpose_a", true) - .Attr("transpose_b", true)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") + 
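// Illustrative sketch (standard op semantics restated): the LRN attributes
// exercised above plug into the definition, applied at each (b, h, w):
//   sqr_sum[d] = sum of in[j]^2 for j in [d - depth_radius, d + depth_radius]
//   out[d]     = in[d] / (bias + alpha * sqr_sum[d])^beta
// which is why depth_radius is capped by the depth dimension. Hypothetical
// reference over one depth vector:
#include <algorithm>
#include <cmath>
#include <vector>
std::vector<float> LrnRef(const std::vector<float>& in, int radius, float bias,
                          float alpha, float beta) {
  const int n = static_cast<int>(in.size());
  std::vector<float> out(n);
  for (int d = 0; d < n; ++d) {
    float sqr_sum = 0.f;
    for (int j = std::max(0, d - radius); j <= std::min(n - 1, d + radius); ++j)
      sqr_sum += in[j] * in[j];
    out[d] = in[d] / std::pow(bias + alpha * sqr_sum, beta);
  }
  return out;
}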
.RandomInput(DT_FLOAT, a_dims) + .RandomInput(DT_FLOAT, b_dims) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b)); }); } TEST_F(OpTest, MatrixDiag) { Repeatedly([this]() { - DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") - .Input(RandomTensor(type, RandomDims(1))) - .Attr("T", type)); + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") + .RandomInput(type, RandomDims(1)) + .Attr("T", type)); }); } TEST_F(OpTest, MatrixDiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_BOOL, DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") - .Input(RandomTensor(type, RandomDims(2))) - .Attr("T", type)); + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") + .RandomInput(type, RandomDims(2)) + .Attr("T", type)); }); } TEST_F(OpTest, Max) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Max").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Max") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1359,26 +1725,28 @@ TEST_F(OpTest, Maximum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, MaxPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); - int kernel_rows = random_int(generator()), - kernel_cols = random_int(generator()); + std::vector dims = RandomDims(4, 4, 1); + int kernel_rows = + std::uniform_int_distribution(1, dims[1])(generator()); + int kernel_cols = + std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); + string padding = Choose({"SAME", "VALID"}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool") - .Input( - RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows), - RandomDim(kernel_cols), RandomDim(1)})) + .RandomInput(DT_FLOAT, dims) .Attr("T", DT_FLOAT) .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) .Attr("strides", {1, stride_rows, stride_cols, 1}) @@ -1388,29 +1756,66 @@ TEST_F(OpTest, MaxPool) { // TODO(phawkins): test NCHW format (not supported by CPU) } +TEST_F(OpTest, MaxPool3D) { + Repeatedly([this]() { + std::uniform_int_distribution random_int(1, 5); + std::vector dims = RandomDims(5, 5, 1); + + std::vector input_dims, kernel_dims, stride_dims; + kernel_dims.push_back(1); + stride_dims.push_back(1); + for (int i = 0; i < 3; ++i) { + kernel_dims.push_back( + std::uniform_int_distribution(1, dims[i])(generator())); + input_dims.push_back(dims[i]); + stride_dims.push_back(random_int(generator())); + } + 
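// Illustrative sketch (hypothetical reimplementation, for reading these
// tests; the real ImageDims helper also handles NCHW): ImageDims appears to
// assemble a full shape from batch, feature and spatial dims for a layout.
#include <vector>
std::vector<long long> ImageDimsNHWC(long long batch, long long feature,
                                     const std::vector<long long>& spatial) {
  std::vector<long long> dims = {batch};                    // N first
  dims.insert(dims.end(), spatial.begin(), spatial.end());  // spatial dims
  dims.push_back(feature);                                  // C last
  return dims;
}
// With batch = feature = 1 and spatial = strides, the same trick yields the
// {1, s..., 1} strides attribute used by the Conv and pooling tests.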
kernel_dims.push_back(1); + stride_dims.push_back(1); + int64 batch = dims[3]; + int64 feature = dims[4]; + + string padding = Choose({"SAME", "VALID"}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("MaxPool3D") + .RandomInput(DT_FLOAT, + ImageDims(FORMAT_NHWC, batch, feature, input_dims)) + .Attr("T", DT_FLOAT) + .Attr("ksize", kernel_dims) + .Attr("strides", stride_dims) + .Attr("padding", padding) + .Attr("data_format", "NDHWC")); + }); + // TODO(phawkins): test NCHW format (not supported by CPU) +} + TEST_F(OpTest, Mean) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); // TODO(phawkins): CPU and XLA differ output for reducing across a // size-0 dimension (nan vs 0). For now, require size >= 1. - Tensor data = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 1)); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(0, kDefaultMaxRank, 1); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Mean").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mean") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, Min) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Min").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Min") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1418,21 +1823,20 @@ TEST_F(OpTest, Minimum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Mod) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Mod") - .Input(RandomTensor(DT_INT32, dims.first)) - .Input(RandomTensor(DT_INT32, dims.second)) - .Attr("T", DT_INT32)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mod") + .RandomInput(DT_INT32, dims.first) + .RandomInput(DT_INT32, dims.second) + .Attr("T", DT_INT32)); }); } @@ -1440,18 +1844,18 @@ TEST_F(OpTest, Mul) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Neg) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Neg").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + 
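// Illustrative sketch (hypothetical helper; standard reduction rule): the
// Max/Min/Mean/Prod/Sum tests all exercise one shape rule via
// RandomReductionIndices and the keep_dims attribute:
#include <set>
#include <vector>
std::vector<long long> ReducedShape(const std::vector<long long>& dims,
                                    const std::set<int>& axes, bool keep_dims) {
  std::vector<long long> out;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (!axes.count(i)) out.push_back(dims[i]);  // untouched axis
    else if (keep_dims) out.push_back(1);        // kept as size 1
    // else: a reduced axis is dropped entirely
  }
  return out;
}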
OpTestBuilder("Neg").RandomInput(type).Attr("T", type)); }); } @@ -1459,10 +1863,48 @@ TEST_F(OpTest, NotEqual) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, OneHot) { + Repeatedly([this]() { + DataType type = Choose(kAllXlaTypes); + + std::vector dims = RandomDims(); + int num_dims = dims.size(); + + int32 depth = RandomDim(); + + Tensor indices(DT_INT32, TensorShape(dims)); + std::uniform_int_distribution distribution(-depth * 2, depth * 2); + test::FillFn(&indices, [this, &distribution](int i) -> int32 { + return distribution(generator()); + }); + + int axis = std::uniform_int_distribution(-num_dims - 5, + num_dims + 5)(generator()); + + OpTestBuilder builder("OneHot"); + builder.Attr("T", type); + builder.Attr("TI", DT_INT32); + builder.Attr("axis", axis); + builder.Input(indices); + builder.Input(test::AsScalar(depth)); + builder.RandomInput(type, {}); + builder.RandomInput(type, {}); + return ExpectTfAndXlaOutputsAreClose(builder); + }); +} + +TEST_F(OpTest, OnesLike) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type)); }); } @@ -1481,9 +1923,9 @@ TEST_F(OpTest, Pack) { builder.Attr("N", n); builder.Attr("axis", axis); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type, dims)); + builder.RandomInput(type, dims); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } @@ -1491,23 +1933,26 @@ TEST_F(OpTest, Pack) { TEST_F(OpTest, Pad) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type); + std::vector t_dims = RandomDims(); // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. // DataType tpaddings = Choose({DT_INT32, DT_INT64}); DataType tpaddings = DT_INT32; std::vector paddings_vec; std::uniform_int_distribution distribution(0, 7); - for (int i = 0; i < t.dims(); ++i) { + for (int i = 0; i < t_dims.size(); ++i) { paddings_vec.push_back(distribution(generator())); paddings_vec.push_back(distribution(generator())); } Tensor paddings; - CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec), - TensorShape({t.dims(), 2}))); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Pad").Input(t).Input(paddings).Attr("T", type).Attr( - "Tpaddings", tpaddings)); + CHECK( + paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec), + TensorShape({static_cast(t_dims.size()), 2}))); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pad") + .RandomInput(type, t_dims) + .Input(paddings) + .Attr("T", type) + .Attr("Tpaddings", tpaddings)); }); } @@ -1516,23 +1961,24 @@ TEST_F(OpTest, Pow) { // nontermination. 
Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Pow") - .Input(RandomTensor(DT_FLOAT, dims.first)) - .Input(RandomTensor(DT_FLOAT, dims.second)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Prod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Prod").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Prod") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } @@ -1547,7 +1993,7 @@ TEST_F(OpTest, Range) { }; std::uniform_int_distribution distribution(-50, 50); DataType tidx = Choose({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Range") .Input(ToScalar(tidx, distribution(generator()))) .Input(ToScalar(tidx, distribution(generator()))) @@ -1559,8 +2005,8 @@ TEST_F(OpTest, Range) { TEST_F(OpTest, Rank) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Rank").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rank").RandomInput(type).Attr("T", type)); }); } @@ -1568,46 +2014,51 @@ TEST_F(OpTest, RealDiv) { Repeatedly([this]() { DataType type = DT_FLOAT; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, Reciprocal) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Relu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu6) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Relu6").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Relu6Grad) { Repeatedly([this]() { auto dims = RandomDims(1); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, ReluGrad) { Repeatedly([this]() { auto dims = RandomDims(1); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return 
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } @@ -1629,39 +2080,68 @@ TEST_F(OpTest, Reshape) { } } } - Tensor data = RandomTensor(type, dims_before); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Reshape") - .Input(data) + .RandomInput(type, dims_before) .Input(test::AsTensor( std::vector(dims_after.begin(), dims_after.end()))) .Attr("T", type)); }); } +TEST_F(OpTest, Reverse) { + Repeatedly([this]() { + std::vector dims = RandomDims(1); + DataType type = Choose({DT_INT32, DT_FLOAT}); + int64 rank = dims.size(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse") + .RandomInput(type, dims) + .RandomInput(DT_BOOL, {rank}) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, ReverseV2) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Round) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Round").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Rsqrt) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, RsqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Shape) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Shape").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Shape").RandomInput(type).Attr("T", type)); }); } @@ -1673,76 +2153,235 @@ TEST_F(OpTest, ShapeN) { builder.Attr("T", type); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - builder.Input(RandomTensor(type)); + builder.RandomInput(type); } - ExpectTfAndXlaOutputsAreClose(builder); + return ExpectTfAndXlaOutputsAreClose(builder); }); } TEST_F(OpTest, Sigmoid) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sigmoid") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, SigmoidGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Sign) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sign").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + 
OpTestBuilder("Sign").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Size) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Size").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Size").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Slice) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); + std::vector data_dims = RandomDims(); - std::vector begin(data.dims()), size(data.dims()); - for (int i = 0; i < data.dims(); ++i) { - begin[i] = std::uniform_int_distribution( - 0, data.dim_size(i))(generator()); + std::vector begin(data_dims.size()), size(data_dims.size()); + for (int i = 0; i < data_dims.size(); ++i) { + begin[i] = + std::uniform_int_distribution(0, data_dims[i])(generator()); size[i] = std::uniform_int_distribution( - -1, data.dim_size(i) - begin[i])(generator()); + -1, data_dims[i] - begin[i])(generator()); } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice") - .Input(data) - .Input(test::AsTensor(begin)) - .Input(test::AsTensor(size)) - .Attr("T", type) - .Attr("Index", DT_INT32)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Slice") + .RandomInput(type, data_dims) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(size)) + .Attr("T", type) + .Attr("Index", DT_INT32)); }); } TEST_F(OpTest, Softmax) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Softmax") - .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2))) + .RandomInput(DT_FLOAT, RandomDims(2, 2)) .Attr("T", DT_FLOAT)); }); } +TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, 2, 1); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SoftmaxCrossEntropyWithLogits") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Softplus) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Softplus").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SoftplusGrad) { + Repeatedly([this]() { + std::vector dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SpaceToBatch) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(4, 4, 0, 5); + const int num_block_dims = 2; + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_size; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") + .RandomInput(DT_FLOAT, input_dims) + .Input(paddings) + .Attr("T", DT_FLOAT) + 
.Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, SpaceToBatchND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_multipliers[i]; + } + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SpaceToBatchND") + .RandomInput(DT_FLOAT, input_dims) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(paddings) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SparseMatMul) { + Repeatedly([this]() { + int64 x = RandomDim(); + int64 y = RandomDim(); + int64 z = RandomDim(); + + std::vector a_dims = {x, y}; + std::vector b_dims = {y, z}; + + std::bernoulli_distribution random_bool; + bool transpose_a = random_bool(generator()); + bool transpose_b = random_bool(generator()); + if (transpose_a) { + std::swap(a_dims[0], a_dims[1]); + } + if (transpose_b) { + std::swap(b_dims[0], b_dims[1]); + } + + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") + .RandomInput(DT_FLOAT, a_dims) + .RandomInput(DT_FLOAT, b_dims) + .Attr("Ta", DT_FLOAT) + .Attr("Tb", DT_FLOAT) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b)); + }); +} + +TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, 2, 1); + int64 batch_size = dims[0]; + int64 num_classes = dims[1]; + + std::vector indices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + indices[i] = + std::uniform_int_distribution(0, num_classes - 1)(generator()); + } + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits") + .RandomInput(DT_FLOAT, dims) + .Input(test::AsTensor(indices)) + .Attr("T", DT_FLOAT) + .Attr("Tlabels", DT_INT32)); + }); +} + TEST_F(OpTest, Split) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); @@ -1754,110 +2393,54 @@ TEST_F(OpTest, Split) { // Ensure 'dim' is evenly divisible by 'n'. 
dims[dim] /= n; dims[dim] *= n; - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") - .Input(test::AsScalar(dim)) - .Input(RandomTensor(type, dims)) - .Attr("T", type) - .Attr("num_split", n)); - }); -} - -TEST_F(OpTest, Softplus) { - Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Softplus") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); - }); -} - -TEST_F(OpTest, SoftplusGrad) { - Repeatedly([this]() { - std::vector dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); - }); -} - -TEST_F(OpTest, SparseMatMul) { - Repeatedly([this]() { - int64 x = RandomDim(); - int64 y = RandomDim(); - int64 z = RandomDim(); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {y, z})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_a", true)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {x, y})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_b", true)); - - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul") - .Input(RandomTensor(DT_FLOAT, {y, x})) - .Input(RandomTensor(DT_FLOAT, {z, y})) - .Attr("Ta", DT_FLOAT) - .Attr("Tb", DT_FLOAT) - .Attr("transpose_a", true) - .Attr("transpose_b", true)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") + .Input(test::AsScalar(dim)) + .RandomInput(type, dims) + .Attr("T", type) + .Attr("num_split", n)); }); } TEST_F(OpTest, Sqrt) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, SquaredDifference) { Repeatedly([this]() { auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("SquaredDifference") - .Input(RandomTensor(DT_FLOAT, dims.first)) - .Input(RandomTensor(DT_FLOAT, dims.second)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SquaredDifference") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Square) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Square").Input(RandomTensor(type)).Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Square").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Squeeze) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 0, 5)); + std::vector t_dims = RandomDims(0, kDefaultMaxRank, 0, 5); std::bernoulli_distribution random_bool; std::vector squeeze_dims; - for (int i = 0; i < t.dims(); ++i) { - if (t.dim_size(i) == 1 && random_bool(generator())) { + for (int i = 0; i < t_dims.size(); ++i) { + if (t_dims[i] == 1 && random_bool(generator())) { squeeze_dims.push_back(i); } } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze") - .Input(t) - .Attr("squeeze_dims", squeeze_dims) - .Attr("T", type)); + return 
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze") + .RandomInput(type, t_dims) + .Attr("squeeze_dims", squeeze_dims) + .Attr("T", type)); }); } @@ -1865,58 +2448,59 @@ TEST_F(OpTest, Sub) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Sum) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - Tensor data = RandomTensor(type); - Tensor indices = RandomReductionIndices(data.dims()); + std::vector data_dims = RandomDims(); + Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sum").Input(data).Input(indices).Attr("T", type).Attr( - "keep_dims", keep_dims)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sum") + .RandomInput(type, data_dims) + .Input(indices) + .Attr("T", type) + .Attr("keep_dims", keep_dims)); }); } TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); - - std::vector begin(data.dims()), end(data.dims()); - std::vector strides(data.dims()); - for (int i = 0; i < data.dims(); ++i) { + std::vector data_dims = RandomDims(); + std::vector begin(data_dims.size()), end(data_dims.size()); + std::vector strides(data_dims.size()); + for (int i = 0; i < data_dims.size(); ++i) { begin[i] = std::uniform_int_distribution( - -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + -2 * data_dims[i], 2 * data_dims[i])(generator()); end[i] = std::uniform_int_distribution( - -2 * data.dim_size(i), 2 * data.dim_size(i))(generator()); + -2 * data_dims[i], 2 * data_dims[i])(generator()); // TODO(b/31360685): support strides other than 1 or -1 strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1; } - int64 max_bitmask = (1LL << data.dims()) - 1; + int64 max_bitmask = (1LL << data_dims.size()) - 1; std::uniform_int_distribution bitmask_distribution(0, max_bitmask); int64 begin_mask = bitmask_distribution(generator()); int64 end_mask = bitmask_distribution(generator()); // Create a ellipsis bitmask with at most one 1 bit set. int64 ellipsis_mask = 0; - if (data.dims() > 0 && std::bernoulli_distribution()(generator())) { - int ellipsis_pos = - std::uniform_int_distribution(0, data.dims() - 1)(generator()); + if (!data_dims.empty() && std::bernoulli_distribution()(generator())) { + int ellipsis_pos = std::uniform_int_distribution( + 0, data_dims.size() - 1)(generator()); ellipsis_mask = 1LL << ellipsis_pos; } int64 new_axis_mask = bitmask_distribution(generator()); int64 shrink_axis_mask = bitmask_distribution(generator()); - ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSlice") - .Input(data) + .RandomInput(type, data_dims) .Input(test::AsTensor(begin)) .Input(test::AsTensor(end)) .Input(test::AsTensor(strides)) @@ -1966,13 +2550,13 @@ TEST_F(OpTest, StridedSliceGrad) { // TODO(phawkins): use shape inference for the forward op to compute the // gradient shape for the backward op. At present, there is a low // probability of the golden op succeeding. 
- ExpectTfAndXlaOutputsAreClose( + return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSliceGrad") .Input(test::AsTensor(dims)) .Input(test::AsTensor(begin)) .Input(test::AsTensor(end)) .Input(test::AsTensor(strides)) - .Input(RandomTensor(type, RandomDims(1))) + .RandomInput(type, RandomDims(1)) .Attr("T", type) .Attr("Index", DT_INT64) .Attr("begin_mask", begin_mask) @@ -1985,48 +2569,48 @@ TEST_F(OpTest, StridedSliceGrad) { TEST_F(OpTest, Tanh) { Repeatedly([this]() { - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tanh") - .Input(RandomTensor(DT_FLOAT)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, TanhGrad) { Repeatedly([this]() { auto dims = RandomDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") - .Input(RandomTensor(DT_FLOAT, dims)) - .Input(RandomTensor(DT_FLOAT, dims)) - .Attr("T", DT_FLOAT)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); }); } TEST_F(OpTest, Tile) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor t = RandomTensor(type, RandomDims(1)); - std::vector multiples(t.dims()); - for (int i = 0; i < t.dims(); ++i) { + std::vector t_dims = RandomDims(1); + std::vector multiples(t_dims.size()); + for (int i = 0; i < t_dims.size(); ++i) { multiples[i] = std::uniform_int_distribution(1, 3)(generator()); } - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tile") - .Input(t) - .Input(test::AsTensor(multiples)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Tile") + .RandomInput(type, t_dims) + .Input(test::AsTensor(multiples)) + .Attr("T", type)); }); } TEST_F(OpTest, Transpose) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); - Tensor data = RandomTensor(type); - std::vector perm(data.dims()); + std::vector data_dims = RandomDims(); + std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); std::shuffle(perm.begin(), perm.end(), generator()); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") - .Input(data) - .Input(test::AsTensor(perm)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") + .RandomInput(type, data_dims) + .Input(test::AsTensor(perm)) + .Attr("T", type)); }); } @@ -2034,10 +2618,10 @@ TEST_F(OpTest, TruncateDiv) { Repeatedly([this]() { DataType type = DT_INT32; auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } @@ -2045,18 +2629,18 @@ TEST_F(OpTest, TruncateMod) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); - ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") - .Input(RandomTensor(type, dims.first)) - .Input(RandomTensor(type, dims.second)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, ZerosLike) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); - ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("ZerosLike").Input(RandomTensor(type)).Attr("T", type)); + return 
ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type)); }); } @@ -2075,6 +2659,9 @@ int main(int argc, char** argv) { tensorflow::Flag("tf_xla_test_repetitions", &tensorflow::tf_xla_test_repetitions, "Number of repetitions for each test."), + tensorflow::Flag("tf_xla_max_tensor_size", + &tensorflow::tf_xla_max_tensor_size, + "Maximum number of elements for random input tensors."), tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, "Tensorflow device type to use for test"), tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py new file mode 100644 index 00000000000..18fabca28c9 --- /dev/null +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -0,0 +1,65 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for XLA Reverse Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReverseOpsTest(XLATestCase): + + def testReverseOneDim(self): + shape = (7, 5, 9, 11) + for revdim in range(len(shape)): + self._AssertReverseEqual([revdim], shape) + + def testReverseMoreThanOneDim(self): + shape = (7, 5, 9, 11) + for revdims in itertools.chain.from_iterable( + itertools.combinations(range(len(shape)), k) + for k in range(2, len(shape)+1)): + self._AssertReverseEqual(revdims, shape) + + def _AssertReverseEqual(self, revdims, shape): + np.random.seed(120) + pval = np.random.randint(0, 100, size=shape).astype(float) + with self.test_session(): + with self.test_scope(): + p = array_ops.placeholder(dtypes.int32, shape=shape) + axis = constant_op.constant( + np.array(revdims, dtype=np.int32), + shape=(len(revdims),), dtype=dtypes.int32) + rval = array_ops.reverse(p, axis).eval({p: pval}) + + slices = [ + slice(-1, None, -1) if d in revdims else slice(None) + for d in range(len(shape))] + self.assertEqual( + pval[slices].flatten().tolist(), + rval.flatten().tolist()) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py new file mode 100644 index 00000000000..ecdce4f052b --- /dev/null +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RMSProp optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import rmsprop + + +class RmspropTest(XLATestCase): + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + rms_opt = rmsprop.RMSPropOptimizer(3.0) + rms_update = rms_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of RMSProp + for _ in range(3): + rms_update.run() + + # Validate updated params + self.assertAllCloseAccordingToType( + np.array([2.91705132e-04, 1.00029182e+00]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([2.89990854, 3.89990854]), var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py new file mode 100644 index 00000000000..4ddf2ee0dcb --- /dev/null +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -0,0 +1,145 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for slicing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class SliceTest(XLATestCase):
+
+  def test1D(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[10])
+        with self.test_scope():
+          o = array_ops.slice(i, [2], [4])
+        params = {
+            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        }
+        result = o.eval(feed_dict=params)
+
+      self.assertAllEqual([2, 3, 4, 5], result)
+
+  def test3D(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
+        with self.test_scope():
+          o = array_ops.slice(i, [1, 2, 2], [1, 1, 4])
+        params = {
+            i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
+                 [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]],
+                [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
+                 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                 [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]],
+                [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
+                 [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+                 [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]]
+        }
+        result = o.eval(feed_dict=params)
+
+      self.assertAllEqual([[[6, 5, 4, 3]]], result)
+
+
+class StridedSliceTest(XLATestCase):
+
+  def test1D(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[10])
+        with self.test_scope():
+          o = array_ops.strided_slice(i, [2], [6], [2])
+        params = {
+            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        }
+        result = o.eval(feed_dict=params)
+
+      self.assertAllEqual([2, 4], result)
+
+  def test1DNegativeStride(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[10])
+        with self.test_scope():
+          o = array_ops.strided_slice(i, [6], [2], [-2])
+        params = {
+            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        }
+        result = o.eval(feed_dict=params)
+
+      self.assertAllEqual([6, 4], result)
+
+  def test3D(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
+        with self.test_scope():
+          o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2])
+        params = {
+            i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
+                 [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]],
+                [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
+                 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                 [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]],
+                [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
+                 [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+                 [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]]
+        }
+        result = o.eval(feed_dict=params)
+
+      self.assertAllEqual([[[1, 9]], [[6, 4]]], result)
+
+  def test3DNegativeStride(self):
+    for dtype in self.numeric_types:
+      with self.test_session():
+        i = array_ops.placeholder(dtype, shape=[3, 4, 10])
+        with self.test_scope():
+          o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2])
+        params = {
+            i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+                 [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
+                 [5, 3, 1, 7, 9, 2, 4, 6, 8, 0],
+                 [4, 5, 2, 4, 3, 7, 6, 8, 9, 4]],
+                [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
+                 [4, 3, 4, 5, 7, 6, 5, 3, 4, 5],
+                 [8, 7, 6, 5, 4, 3, 2, 1, 8, 7],
+                 [7, 1, 7, 1, 8, 1, 8, 1, 3, 1]],
+                [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
+                 [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
+                 [9, 8, 7, 9, 8, 7, 9, 8, 7, 9],
+                 [9, 9, 5, 5, 6, 6, 3, 3, 6, 6]]]
+        }
+        result = o.eval(feed_dict=params)
+
self.assertAllEqual([[[9, 8], + [1, 1]], + [[2, 4], + [5, 7]]], result) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py new file mode 100644 index 00000000000..9c3b86c84b2 --- /dev/null +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -0,0 +1,266 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for SpaceToBatch and BatchToSpace ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import test + + +def space_to_batch_direct(input_array, block_shape, paddings): + """Direct Python implementation of space-to-batch conversion. + + This is used for tests only. + + Args: + input_array: N-D array + block_shape: 1-D array of shape [num_block_dims]. + paddings: 2-D array of shape [num_block_dims, 2]. + + Returns: + Converted tensor. 
+ """ + input_array = np.array(input_array) + block_shape = np.array(block_shape) + num_block_dims = len(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + + padded = np.pad(input_array, + pad_width=([[0, 0]] + list(paddings) + [[0, 0]] * + (input_array.ndim - 1 - num_block_dims)), + mode="constant") + reshaped_padded_shape = [input_array.shape[0]] + output_shape = [input_array.shape[0] * np.prod(block_shape)] + for block_dim, block_shape_value in enumerate(block_shape): + reduced_size = padded.shape[block_dim + 1] // block_shape_value + reshaped_padded_shape.append(reduced_size) + output_shape.append(reduced_size) + reshaped_padded_shape.append(block_shape_value) + reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:]) + output_shape.extend(input_array.shape[num_block_dims + 1:]) + + reshaped_padded = padded.reshape(reshaped_padded_shape) + permuted_reshaped_padded = np.transpose(reshaped_padded, ( + list(np.arange(num_block_dims) * 2 + 2) + [0] + + list(np.arange(num_block_dims) * 2 + 1) + list( + np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims + * 2))) + return permuted_reshaped_padded.reshape(output_shape) + + +class SpaceToBatchTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" + + def _testPad(self, inputs, paddings, block_size, outputs): + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + # outputs = space_to_batch(inputs) + placeholder = array_ops.placeholder(dtype) + x_tf = gen_array_ops._space_to_batch( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + x_tf = gen_array_ops._batch_to_space( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testOne(self, inputs, block_size, outputs): + paddings = np.zeros((2, 2), dtype=np.int32) + self._testPad(inputs, paddings, block_size, outputs) + + # [1, 2, 2, 1] <-> [4, 1, 1, 1] + def testSmallInput2x2(self): + x_np = [[[[1], [2]], [[3], [4]]]] + block_size = 2 + x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + self._testOne(x_np, block_size, x_out) + + # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1] + def testSmallInput2x2Pad1x0(self): + x_np = [[[[1], [2]], [[3], [4]]]] + paddings = np.array([[1, 0], [1, 0]], dtype=np.int32) + block_size = 3 + x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]], + [[[3]]], [[[4]]]] + self._testPad(x_np, paddings, block_size, x_out) + + # Test with depth larger than 1. + # [1, 2, 2, 3] <-> [4, 1, 1, 3] + def testDepthInput2x2(self): + x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] + block_size = 2 + x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] + self._testOne(x_np, block_size, x_out) + + # Test for larger input dimensions. + # [1, 4, 4, 1] <-> [4, 2, 2, 1] + def testLargerInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]], + [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Test with batch larger than 1. 
+ # [2, 2, 4, 1] <-> [8, 1, 2, 1] + def testBatchInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]], + [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Tests for larger input spatial dimensions AND batch larger than 1, to ensure + # that elements are correctly laid out spatially and properly interleaved + # along the batch dimension. + # [2, 4, 4, 1] <-> [8, 2, 2, 1] + def testLargerInputBatch2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]], + [[[17], [18], [19], [20]], [[21], [22], [23], [24]], + [[25], [26], [27], [28]], [[29], [30], [31], [32]]]] + x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]], + [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]], + [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]], + [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]] + block_size = 2 + self._testOne(x_np, block_size, x_out) + + +class SpaceToBatchNDTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" + + def _testPad(self, inputs, block_shape, paddings, outputs): + block_shape = np.array(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + placeholder = array_ops.placeholder(dtype) + # outputs = space_to_batch(inputs) + x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + placeholder = array_ops.placeholder(dtype) + x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testDirect(self, input_shape, block_shape, paddings): + inputs = np.arange(np.prod(input_shape), dtype=np.float32) + inputs = inputs.reshape(input_shape) + self._testPad(inputs, block_shape, paddings, + space_to_batch_direct(inputs, block_shape, paddings)) + + def testZeroBlockDimsZeroRemainingDims(self): + self._testPad( + inputs=[1, 2], + block_shape=[], + paddings=[], + outputs=[1, 2],) + + def testZeroBlockDimsOneRemainingDim(self): + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[], + paddings=[], + outputs=[[1, 2], [3, 4]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[1, 2], [3, 4]]) + + def testZeroBlockDimsTwoRemainingDims(self): + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[], + paddings=[], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with two no-op block dims. 
+ self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1, 1], + paddings=[[0, 0], [0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + def testOneBlockDimZeroRemainingDims(self): + self._testPad( + inputs=[[1, 2, 3], [4, 5, 6]], + block_shape=[2], + paddings=[1, 0], + outputs=[[0, 2], [0, 5], [1, 3], [4, 6]]) + + def testOneBlockDimOneRemainingDim(self): + self._testPad( + inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]], + block_shape=[2], + paddings=[1, 0], + outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], + [[4, 41], [6, 61]]]) + + def testDirect(self): + # Test with zero-size remaining dimension. + self._testDirect( + input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + + # Test with zero-size blocked dimension. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + + # Test with padding up from zero size. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2], + paddings=[[1, 2], [0, 0], [3, 0]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2, 2], + paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2, 1], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py new file mode 100644 index 00000000000..27a29773053 --- /dev/null +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -0,0 +1,1018 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Functional tests for XLA TensorArray Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def _make_converter(dtype): + def _converter(x): + return np.asarray(x).astype(dtype.as_numpy_dtype) + return _converter + + +class TensorArrayTest(xla_test.XLATestCase): + + def testTensorArrayWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [[1.0, 3.0]]) + w2 = w1.write(2, [[7.0, -8.5]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual([[4.0, 5.0]], d0) + self.assertAllEqual([[1.0, 3.0]], d1) + self.assertAllEqual([[7.0, -8.5]], d2) + + def _testTensorArrayWritePack(self, tf_dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + w0 = ta.write(0, convert([[4.0, 5.0]])) + w1 = w0.write(1, convert([[6.0, 7.0]])) + w2 = w1.write(2, convert([[8.0, 9.0]])) + + c0 = w2.stack() + + self.assertAllEqual( + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + + def testTensorArrayWritePack(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWritePack(dtype) + + def testEmptyTensorArrayPack(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + empty_element = np.zeros((0, 1), dtype=np.float32) + w0 = ta.write(0, empty_element) + w1 = w0.write(1, empty_element) + w2 = w1.write(2, empty_element) + + c0 = w2.stack() + + self.assertAllEqual([3, 0, 1], c0.eval().shape) + + def _testTensorArrayWriteConcat(self, tf_dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]])) + w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]])) + w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]])) + + c0 = w2.concat() + + self.assertAllEqual( + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], + [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + + def testTensorArrayWriteConcat(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWriteConcat(dtype) + + def _testTensorArrayUnpackRead(self, tf_dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = 
_make_converter(tf_dtype) + + # Unpack a vector into scalars + w0 = ta.unstack(convert([1.0, 2.0, 3.0])) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert(1.0), d0) + self.assertAllEqual(convert(2.0), d1) + self.assertAllEqual(convert(3.0), d2) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + # Unpack a matrix into vectors + w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) + r0 = w1.read(0) + r1 = w1.read(1) + r2 = w1.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([1.0, 1.1]), d0) + self.assertAllEqual(convert([2.0, 2.1]), d1) + self.assertAllEqual(convert([3.0, 3.1]), d2) + + # Reset ta because we're going to change the shape, else shape + # inference will throw an error. + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + # Try unpacking an empty matrix, which should not cause an error. + w2 = ta.unstack(convert([[], [], []])) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([]), d0) + self.assertAllEqual(convert([]), d1) + self.assertAllEqual(convert([]), d2) + + def _testTensorArrayUnpackReadMaybeLegacy(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayUnpackRead(dtype) + + def testTensorArrayUnpackRead(self): + self._testTensorArrayUnpackReadMaybeLegacy() + + def _testTensorArraySplitRead(self, tf_dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + convert = _make_converter(tf_dtype) + + # Split an empty vector + lengths = constant_op.constant([0, 0, 0]) + w0 = ta.split(convert([]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([]), d0) + self.assertAllEqual(convert([]), d1) + self.assertAllEqual(convert([]), d2) + + # Split a vector + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([1.0]), d0) + self.assertAllEqual(convert([2.0]), d1) + self.assertAllEqual(convert([3.0]), d2) + + # Split a matrix + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split( + convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + d0, d1, d2 = session.run([r0, r1, r2]) + self.assertAllEqual(convert([[1.0, 101.0]]), d0) + self.assertAllEqual(convert([[2.0, 201.0]]), d1) + self.assertAllEqual(convert([[3.0, 301.0]]), d2) + + def testTensorArraySplitRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArraySplitRead(dtype) + + def testTensorGradArrayWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) + + g_ta = w2.grad("grad") + + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 
= w2.read(2) + + g_r0 = g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2]) + self.assertAllEqual([[4.0]], d0) + self.assertAllEqual([[1.0]], d1) + self.assertAllEqual([[-3.0]], d2) + self.assertAllEqual([[5.0]], g_d0) + self.assertAllEqual([[2.0]], g_d1) + self.assertAllEqual([[-2.0]], g_d2) + + def testTensorGradArrayDynamicWriteRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) + + g_ta = w2.grad("grad") # Get gradient array here so we know the shape + + s = w2.size() + g_s = g_ta.size() + + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) + + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + + g_r0 = g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run( + [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s]) + self.assertAllEqual([[4.0]], d0) + self.assertAllEqual([[1.0]], d1) + self.assertAllEqual([[-3.0]], d2) + self.assertAllEqual([[5.0]], g_d0) + self.assertAllEqual([[2.0]], g_d1) + self.assertAllEqual([[-2.0]], g_d2) + self.assertAllEqual(3, vs) + self.assertAllEqual(3, g_vs) + + def testTensorGradAccessTwiceReceiveSameObject(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3, + element_shape=[1, 2]) + g_ta_0 = ta.grad("grad") + g_ta_1 = ta.grad("grad") + + with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]): + # Write with one gradient handle, read with another copy of it + r1_0 = g_ta_1.read(0) + + t_g_ta_0, t_g_ta_1, d_r1_0 = session.run( + [g_ta_0.handle.op, g_ta_1.handle.op, r1_0]) + self.assertAllEqual(t_g_ta_0, t_g_ta_1) + self.assertAllEqual([[4.0, 5.0]], d_r1_0) + + def testTensorArrayWriteWrongIndexOrDataTypeFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + # Test writing the wrong datatype + with self.assertRaisesOpError( + "TensorArray dtype is float but op has dtype int32"): + ta.write(-1, np.int32(7)).flow.eval() + + def testTensorArrayReadWrongIndexOrDataTypeFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + + w0 = ta.write(0, [[4.0, 5.0]]) + + # Test reading wrong datatype + r0_bad = gen_data_flow_ops._tensor_array_read_v3( + handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) + with self.assertRaisesOpError( + "TensorArray dtype is float but Op requested dtype double."): + r0_bad.eval() + + # Test reading from a different index than the one we wrote to + w0.read(1) + + def testTensorArraySplitIncompatibleShapesFails(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + + with self.assertRaisesOpError( + r"value is not 1D"): + lengths = array_ops.placeholder(dtypes.int64) + ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) + + with self.assertRaisesOpError( + r"lengths must be equal: 1 vs. 
2"): + ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow.eval() + + with self.assertRaisesOpError( + r"value must have rank >= 1"): + ta.split(1.0, [1]).flow.eval() + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + infer_shape=False) + + with self.assertRaisesOpError( + r"TensorArray's size is not equal to the size of lengths " + r"\(1 vs. 2\)"): + ta.split([1.0], [1]).flow.eval() + + def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) + + c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype) + + w0 = ta.write(2, c(3.0)) + w1 = w0.write(2, c(4.0)) + + ta_grad = w1.grad("grad") + + w0_grad = ta_grad.write(2, c(3.0)) + w1_grad = w0_grad.write(2, c(4.0)) + w2_grad = w1_grad.write(2, c(5.0)) + + # Assert that aggregation works correctly + self.assertAllEqual(c(12.00), w2_grad.read(2).eval()) + + # Using differing shapes causes an exception + wb0_grad = ta_grad.write(1, c(1.0)) + wb1_grad = wb0_grad.write(1, c([1.0])) + + with self.assertRaisesOpError( + r"Mismatched TensorArray sizes"): + wb1_grad.flow.eval() + + def testTensorArrayWriteGradientAddMultipleAdds(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayWriteGradientAddMultipleAdds(dtype) + + def testMultiTensorArray(self): + with self.test_session(), self.test_scope(): + h1 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="foo") + w1 = h1.write(0, 4.0) + r1 = w1.read(0) + + h2 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="bar") + + w2 = h2.write(0, 5.0) + r2 = w2.read(0) + r = r1 + r2 + self.assertAllClose(9.0, r.eval()) + + def _testTensorArrayGradientWriteReadType(self, dtype): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.as_dtype(dtype), + tensor_array_name="foo", + size=3, + infer_shape=False) + + c = lambda x: np.array(x, dtype=dtype) + + value_0 = constant_op.constant(c([[4.0, 5.0]])) + value_1 = constant_op.constant(c([[3.0, 3.5]])) + + w0 = ta.write(0, value_0) + w1 = w0.write(1, value_1) + r0 = w1.read(0) + r1 = w1.read(1) + r0_2 = w1.read(0) + + # Test individual components' gradients + grad_just_r0 = gradients_impl.gradients( + ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])]) + grad_just_r0_vals = session.run(grad_just_r0) + self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0]) + + grad_r0_r0_2 = gradients_impl.gradients( + ys=[r0, r0_2], + xs=[value_0], + grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])]) + grad_r0_r0_2_vals = session.run(grad_r0_r0_2) + self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0]) + + grad_just_r1 = gradients_impl.gradients( + ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])]) + grad_just_r1_vals = session.run(grad_just_r1) + self.assertAllEqual(c([[-2.0, -4.0]]), grad_just_r1_vals[0]) + + # Test combined gradients + grad = gradients_impl.gradients( + ys=[r0, r0_2, r1], + xs=[value_0, value_1], + grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c([[-2.0, -10.0]])]) + grad_vals = session.run(grad) + self.assertEqual(len(grad_vals), 2) + self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0]) + self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) + + def testTensorArrayGradientWriteRead(self): + for dtype in self.numeric_types: + self._testTensorArrayGradientWriteReadType(dtype) + + def _testTensorArrayGradientWritePackConcatAndRead(self): + with 
self.test_session() as sess, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + + value_0 = constant_op.constant([-1.0, 1.0]) + value_1 = constant_op.constant([-10.0, 10.0]) + + w0 = ta.write(0, value_0) + w1 = w0.write(1, value_1) + p0 = w1.stack() + r0 = w1.read(0) + s0 = w1.concat() + + # Test gradient accumulation between read(0), pack(), and concat() + with ops.control_dependencies([p0, r0, s0]): + grad_r = gradients_impl.gradients( + ys=[p0, r0, s0], + xs=[value_0, value_1], + grad_ys=[ + [[2.0, 3.0], [4.0, 5.0]], # stack gradient + [-0.5, 1.5], # read(0) gradient + [20.0, 30.0, 40.0, 50.0], # concat gradient + ]) + grad_vals = sess.run(grad_r) # 2 + 2 entries + + self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) + self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) + + def testTensorArrayGradientWritePackConcatAndRead(self): + self._testTensorArrayGradientWritePackConcatAndRead() + + def testTensorArrayReadTwice(self): + with self.test_session(), self.test_scope(): + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + ta_readtwice = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + w_readtwice = ta_readtwice.unstack(value) + r0_readtwice = w_readtwice.read(0) + with ops.control_dependencies([r0_readtwice]): + r1_readtwice = w_readtwice.read(0) + + self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + + def _testTensorArrayGradientUnpackRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + w = ta.unstack(value) + r0 = w.read(0) + r0_1 = w.read(0) + r1 = w.read(1) + + # Test combined gradients + aggregation of read(0) + grad = gradients_impl.gradients( + ys=[r0, r0_1, r1], + xs=[value], + grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]]) + grad_vals = session.run(grad) + + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0]) + + def testTensorArrayGradientUnpackRead(self): + self._testTensorArrayGradientUnpackRead() + + def testTensorArrayGradientSplitConcat(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=2) + + value = constant_op.constant( + [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0], [1000.0, -1000.0]]) + + w = ta.split(value, [2, 2]) + r = w.concat() + + # Test combined gradients + grad = gradients_impl.gradients( + ys=[r], + xs=[value], + grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], + [2000.0, -2000.0]]]) + grad_vals = session.run(grad) + + self.assertEqual(len(grad_vals), 1) + self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], + [2000.0, -2000.0]], + grad_vals[0]) + + # TODO(phawkins): implement TensorArrayClose + # def testCloseTensorArray(self): + # with self.test_session() as session, self.test_scope(): + # ta = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, tensor_array_name="foo", size=3) + # c1 = ta.close() + # session.run(c1) + + def testSizeTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + s = ta.size() + self.assertAllEqual(3, s.eval()) + + # TODO(phawkins): 
implement TensorArrayClose + # def testWriteCloseTensorArray(self): + # with self.test_session(), self.test_scope(): + # ta = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, + # tensor_array_name="foo", + # size=3, + # infer_shape=False) + # w0 = ta.write(0, [[4.0, 5.0]]) + # w1 = w0.write(1, [3.0]) + # w1.close().run() # Expected to run without problems + + # TODO(phawkins): implement while loops. + # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): + # np_dtype = dtype.as_numpy_dtype + # with self.test_session() as session, self.test_scope(): + # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) + # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) + # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) + # ta = tensor_array_ops.TensorArray( + # dtype=dtype, + # tensor_array_name="foo", + # size=0 if dynamic_size else 3, + # dynamic_size=dynamic_size) + # time_0 = array_ops.identity(0) + + # def body(time, ta_t, state): + # sliced = array_ops.slice( + # v0, begin=array_ops.stack([time, 0]), size=[1, -1]) + # sliced = array_ops.squeeze(sliced) + # out = sliced + var + state + # state += sliced + # ta_t = ta_t.write(time, out) + # return (time + 1, ta_t, state) + + # (unused_0, h_final, unused_2) = control_flow_ops.while_loop( + # cond=lambda time, unused_1, unused_2: time < 3, + # body=body, + # loop_vars=(time_0, ta, state0), + # shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(), + # tensor_shape.unknown_shape()), + # parallel_iterations=3) + # vout = h_final.stack() + + # grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5) + # v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0] + # state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0] + # var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0] + + # variables.global_variables_initializer().run() + # state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = ( + # session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]) + # ) + # just_v0_grad_t, = session.run([v0_grad]) + + # # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ] + # # vout = [ v0[0] + var + state[0] | + # # v0[1] + var + state[1] | + # # v0[2] + var + state[2] ] + # # = [ v0[0] + var + state0 | + # # v0[1] + var + state0 + v0[0] | + # # v0[2] + var + state0 + v0[0] + v0[1] ] + # # + # # d(vout[0])/d(v0) = [1 | 0 | 0 ] + # # d(vout[1])/d(v0) = [1 | 1 | 0 ] + # # d(vout[2])/d(v0) = [1 | 1 | 1 ] + # # d(vout)/d(var) = [1 | 1 | 1] + # # d(vout)/d(state0) = [ 1 | 1 | 1 ] + + # state_per_time = np.array( + # [state0_t, state0_t + v0_t[0, :], + # state0_t + v0_t[0, :] + v0_t[1, :]]) + + # # Compare forward prop + # self.assertAllClose(v0_t + var_t + state_per_time, vout_t) + + # # Compare backward prop + # expected_v0_grad_t = np.array([ + # grad_val[0, :] + grad_val[1, :] + grad_val[2, :], + # grad_val[1, :] + grad_val[2, :], grad_val[2, :] + # ]) + + # self.assertAllEqual(expected_v0_grad_t, v0_grad_t) + # self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t) + # self.assertAllClose(grad_val.sum(axis=0), var_grad_t) + # self.assertAllClose(grad_val.sum(axis=0), state0_grad_t) + + # def testWhileLoopWritePackGradients(self): + # self._testWhileLoopWritePackGradients( + # dynamic_size=False, dtype=dtypes.float32) + # # TODO(ebrevdo): re-enable when While supports non-float32 gradients. 
+ # # self._testWhileLoopWritePackGradients( + # # dynamic_size=False, dtype=tf.int64) + + # def testWhileLoopDynamicWritePackGradients(self): + # self._testWhileLoopWritePackGradients( + # dynamic_size=True, dtype=dtypes.float32) + + # def testGradSerialTwoLoops(self): + # with self.test_session(), self.test_scope(): + # num_steps = 100 + # acc = tensor_array_ops.TensorArray( + # dtype=dtypes.float32, + # size=num_steps, + # clear_after_read=False, + # element_shape=tensor_shape.scalar()) + # i = constant_op.constant(0, name="i") + # x = constant_op.constant(2.0, name="x") + + # c = lambda i, acc: i < 5 + + # def b(i, acc): + # x1 = control_flow_ops.cond( + # math_ops.equal(i, 0), lambda: x, + # lambda: math_ops.multiply(acc.read(i - 1), 2.0)) + # return i + 1, acc.write(i, x1) + + # i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc]) + + # z = constant_op.constant(0.0) + + # def fn(i, acc): + # return i + 1, acc.write(i, z) + + # _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn, + # [i1, acc1]) + + # r = acc2.stack() + # grad = gradients_impl.gradients(r, [x])[0] + # self.assertAllClose(31.0, grad.eval()) + + def testSumOfTwoReadVariablesWithoutRepeatGrad(self): + with self.test_session() as session, self.test_scope(): + a = array_ops.identity( + np.arange( + 3 * 5, dtype=np.float32).reshape(3, 5) + 1) + b = array_ops.identity( + np.arange( + 3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5) + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + ta = ta.write(0, a, name="write_a") + ta = ta.write(1, b, name="write_b") + c = ( + ta.read( + 0, name="read_a_0") + # a + b + ta.read( + 1, name="read_b_0")) + g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1) + grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1 + grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1 + + # Test gradients calculated individually + grad_a_t, = session.run([grad_a]) + self.assertAllEqual(grad_a_t, g0) + + grad_b_t, = session.run([grad_b]) + self.assertAllEqual(grad_b_t, g0) + + # Test gradients calculated jointly + joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b]) + self.assertAllEqual(joint_grad_a_t, g0) + self.assertAllEqual(joint_grad_b_t, g0) + + def testWriteShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c0 = constant_op.constant([4.0, 5.0]) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = constant_op.constant([6.0, 7.0]) + w1 = w0.write(1, c1) + r0 = w1.read(0) + r1 = w1.read(1) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) + self.assertAllEqual(c1.get_shape(), r1.get_shape()) + + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c2 = constant_op.constant([4.0, 5.0, 6.0]) + with self.assertRaises(ValueError): + w0.write(0, c2) + + def testPartlyUnknownShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=6) + + c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list()) + + c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) + w1 = w0.write(1, c1) + r1 = w1.read(0) + 
self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list()) + + # Writing less specific shape (doesn't change type.) + c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None]) + w2 = w1.write(2, c2) + r2 = w2.read(0) + self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list()) + + # Writing more specific shape in one dimension and less specific in + # another. + c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None]) + w3 = w2.write(3, c3) + r3 = w3.read(0) + self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list()) + + # Writing partly defined shape using TensorArray.scatter. + c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3]) + w4 = w3.scatter([4, 5], c4) + r4 = w4.read(0) + self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list()) + + # Writing fully defined shape using TensorArray.split. + c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3]) + w5 = w4.split(c5, constant_op.constant([5, 5])) + r5 = w5.read(0) + self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) + + def _testUnpackShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=0, + infer_shape=True) + value = constant_op.constant( + [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]]) + w0 = ta.unstack(value) + r0 = w0.read(0) + self.assertAllEqual((2,), r0.get_shape()) + + c1 = constant_op.constant([4.0, 5.0]) + w1 = w0.write(3, c1) + r1 = w1.read(0) + self.assertAllEqual(c1.get_shape(), r1.get_shape()) + + c2 = constant_op.constant([4.0, 5.0, 6.0]) + with self.assertRaises(ValueError): + w1.write(4, c2) + + def testUnpackShape(self): + self._testUnpackShape() + + def testSplitShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=0, + infer_shape=True) + value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]) + w0 = ta.split(value, [1, 1, 1]) + r0 = w0.read(0) + self.assertAllEqual((1, 2), r0.get_shape()) + + ta1 = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo1", + size=0, + infer_shape=True) + w0 = ta1.split(value, [1, 2]) + r0 = w0.read(0) + self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) + + def testWriteUnknownShape(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=True) + c0 = array_ops.placeholder(dtypes.float32) + w0 = ta.write(0, c0) + r0 = w0.read(0) + self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) + + def _testGradientWhenNotAllComponentsRead(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + x = constant_op.constant([2.0, 3.0]) + w = ta.unstack(x) + r0 = w.read(0) + # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). + grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0]) + grad_r0_vals = session.run(grad_r0)[0] + self.assertAllEqual(grad_r0_vals, [1.0, 0.0]) + + def testGradientWhenNotAllComponentsRead(self): + self._testGradientWhenNotAllComponentsRead() + + def _testTensorArrayEvalEmpty(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=False) + with self.assertRaisesOpError( + "TensorArray has size zero, but element shape is not fully " + "defined. 
Currently only static shapes are supported when packing "
+          "zero-size TensorArrays."):
+        ta.stack().eval()
+
+  def testTensorArrayEvalEmpty(self):
+    self._testTensorArrayEvalEmpty()
+
+  def _testTensorArrayEvalEmptyWithDefault(self):
+    with self.test_session(), self.test_scope():
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32, size=0, infer_shape=True)
+      self.assertEqual(0, ta.size().eval())
+      ta = ta.unstack(array_ops.zeros([0, 3, 5]))
+      packed = ta.stack()
+      self.assertAllEqual([0, 3, 5], packed.eval().shape)
+      # Concatenating zero tensors along their first dimension gives a
+      # first dimension of zero.
+      self.assertAllEqual([0, 5], ta.concat().eval().shape)
+
+  def testTensorArrayEvalEmptyWithDefault(self):
+    self._testTensorArrayEvalEmptyWithDefault()
+
+  def testTensorArrayScatterReadAndGradients(self):
+    with self.test_session() as session, self.test_scope():
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32,
+          tensor_array_name="foo",
+          size=10)
+
+      indices = constant_op.constant([1, 8])
+      value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+      w = ta.scatter(indices, value)
+      r0 = w.read(1)
+      r1 = w.read(8)
+
+      # Test combined gradients of the two reads.
+      grad = gradients_impl.gradients(
+          ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
+      read_vals, grad_vals = session.run([[r0, r1], grad])
+
+      self.assertEqual(len(read_vals), 2)
+      self.assertEqual(len(grad_vals), 1)
+      self.assertAllEqual([1.0, -1.0], read_vals[0])
+      self.assertAllEqual([10.0, -10.0], read_vals[1])
+      self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
+
+  def testTensorArrayWriteGatherAndGradients(self):
+    with self.test_session() as session, self.test_scope():
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32,
+          tensor_array_name="foo",
+          size=10)
+
+      values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
+      indices = constant_op.constant([1, 8])
+
+      w = ta.unstack(values)
+      g = w.gather(indices)
+
+      # Test the gradient of the gather.
+      grad = gradients_impl.gradients(
+          ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]])
+      g_vals, grad_vals = session.run([[g], grad])
+
+      # Gradients for the 8 unread components are zero.
+      expected_grad = np.zeros((10, 2))
+      expected_grad[1] = [2.0, 3.0]
+      expected_grad[8] = [4.0, 5.0]
+
+      self.assertEqual(len(g_vals), 1)
+      self.assertEqual(len(grad_vals), 1)
+      self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
+      self.assertAllEqual(expected_grad, grad_vals[0])
+
+  def testTensorArrayIdentity(self):
+    with self.test_session() as session, self.test_scope():
+      ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
+                                         infer_shape=False)
+      ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
+                                         infer_shape=True)
+
+      ta0 = ta0.write(0, 0.)
+      ta1 = ta1.write(0, 1)
+
+      v0 = resource_variable_ops.ResourceVariable(0)
+      v1 = resource_variable_ops.ResourceVariable(0)
+
+      with ops.control_dependencies([v0.assign_add(1)]):
+        ta0 = ta0.identity()
+
+      with ops.control_dependencies([v1.assign_add(1)]):
+        ta1 = ta1.identity()
+
+      read0 = ta0.read(0)
+      read1 = ta1.read(0)
+
+      size0 = ta0.size()
+      size1 = ta1.size()
+
+      # Tests correct properties on new TensorArrays.
+      self.assertEqual(dtypes.float32, ta0.dtype)
+      self.assertEqual(dtypes.int32, ta1.dtype)
+      self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
+      self.assertEqual(tensor_shape.scalar(), read1.get_shape())
+
+      variables.global_variables_initializer().run()
+
+      read0_v, read1_v, size0_v, size1_v = session.run(
+          (read0, read1, size0, size1))
+
+      # Tests that the control dependencies were added and executed.
+      self.assertEqual(1, v0.eval())
+      self.assertEqual(1, v1.eval())
+
+      # Tests that reads and sizes come from the correct TensorArray.
+      self.assertEqual(read0_v, 0)
+      self.assertEqual(read1_v, 1)
+      self.assertEqual(size0_v, 2)
+      self.assertEqual(size1_v, 4)
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 22024f45116..ba5f829936f 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -75,6 +75,20 @@ class TernaryOpsTest(XLATestCase):
         np.array(7, dtype=np.float32),
         expected=np.array(7, dtype=np.float32))
 
+    self._testTernary(
+        array_ops.where,
+        np.array(1, dtype=np.bool),
+        np.array([1, 2, 3, 4], dtype=np.float32),
+        np.array([5, 6, 7, 8], dtype=np.float32),
+        expected=np.array([1, 2, 3, 4], dtype=np.float32))
+
+    self._testTernary(
+        array_ops.where,
+        np.array(0, dtype=np.bool),
+        np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
+        np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
+        expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32))
+
     self._testTernary(
         array_ops.where,
         np.array([0, 1, 1, 0], dtype=np.bool),
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index f0b80d1ffdb..51d8786ce3d 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -159,6 +159,13 @@ class UnaryOpsTest(XLATestCase):
         np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
         expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
 
+    self._assertOpOutputMatchesExpected(
+        math_ops.round,
+        np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
+                  [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
+        expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
+                          dtype=dtype))
+
     self._assertOpOutputMatchesExpected(
         math_ops.rsqrt,
         np.array([[4, 16]], dtype=dtype),
@@ -175,6 +182,11 @@ class UnaryOpsTest(XLATestCase):
                       [0.7310586, 0.880797, 0.95257413, 0.98201376]],
                      dtype=dtype))
 
+    self._assertOpOutputMatchesExpected(
+        math_ops.sigmoid,
+        np.array([-300, -150, 0, 150, 300], dtype=dtype),
+        expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype))
+
     self._assertOpOutputMatchesExpected(
         math_ops.sqrt,
         np.array([[4, 9]], dtype=dtype),
@@ -202,6 +214,11 @@ class UnaryOpsTest(XLATestCase):
                   [-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
                  dtype=dtype))
 
+    self._assertOpOutputMatchesExpected(
+        nn_ops.elu,
+        np.array([[-1, 0, 1]], dtype=dtype),
+        expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
+
     self._assertOpOutputMatchesExpected(
         nn_ops.relu,
         np.array([[-1, 1]], dtype=dtype),
@@ -250,6 +267,11 @@ class UnaryOpsTest(XLATestCase):
         np.array([[4, 3], [2, 1]], dtype=dtype),
         expected=np.array([[0, 0], [0, 0]], dtype=dtype))
 
+    self._assertOpOutputMatchesExpected(
+        array_ops.ones_like,
+        np.array([[4, 3], [2, 1]], dtype=dtype),
+        expected=np.array([[1, 1], [1, 1]], dtype=dtype))
+
   def testLogicalOps(self):
     self._assertOpOutputMatchesExpected(
         math_ops.logical_not,
diff --git a/tensorflow/compiler/tests/variable_ops_test.py
b/tensorflow/compiler/tests/variable_ops_test.py new file mode 100644 index 00000000000..70dacd9de4b --- /dev/null +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -0,0 +1,183 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reading and writing variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.training.gradient_descent import GradientDescentOptimizer + + +class VariableOpsTest(XLATestCase): + """Test cases for resource variable operators.""" + + def testOneWriteOneOutput(self): + # Regression test for a bug where computations with one non-constant + # output and one variable update were mishandled. 
+ for dtype in self.numeric_types: + init = np.array([[1, 2], [3, 4]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + x = v.assign_add(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), + sess.run(y, {p: 1})) + + def testReadWrite(self): + """Tests initialization, reading, and writing a resource variable.""" + with self.test_session() as session: + with self.test_scope(): + with variable_scope.variable_scope("ascope", use_resource=True): + x = variable_scope.get_variable( + "x", + shape=[], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer(2)) + a = x.read_value() + with ops.control_dependencies([a]): + b = state_ops.assign(x, 47) + with ops.control_dependencies([b]): + c = x.read_value() + with ops.control_dependencies([c]): + d = state_ops.assign_add(x, 3) + with ops.control_dependencies([d]): + e = x.read_value() + + session.run(variables.global_variables_initializer()) + v1, v2, v3 = session.run([a, c, e]) + self.assertAllClose(2.0, v1) + self.assertAllClose(47.0, v2) + self.assertAllClose(50.0, v3) + + def testTraining(self): + """Tests a gradient descent step for a simple model.""" + with self.test_session() as session: + with self.test_scope(): + with variable_scope.variable_scope("ascope", use_resource=True): + w = variable_scope.get_variable( + "w", + shape=[4, 2], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32))) + b = variable_scope.get_variable( + "b", + shape=[2], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + np.array([2, 3], dtype=np.float32))) + + x = array_ops.placeholder(dtypes.float32, shape=[1, 4]) + y = math_ops.matmul(x, w) + b + loss = math_ops.reduce_sum(y) + optimizer = GradientDescentOptimizer(0.1) + train = optimizer.minimize(loss) + + session.run(variables.global_variables_initializer()) + session.run(train, {x: np.array([[7, 3, 5, 9]], dtype=np.float32)}) + vw, vb = session.run([w, b]) + self.assertAllClose( + np.array( + [[0.3, 1.3], [2.7, 3.7], [4.5, 5.5], [6.1, 7.1]], + dtype=np.float32), + vw, + rtol=1e-4) + self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + + +class StridedSliceAssignChecker(object): + """Compares the results of a slice assignment using Tensorflow and numpy.""" + + def __init__(self, test, x, dtype): + self.dtype = dtype + self.test = test + self.x_np = np.array(x).astype(dtype) + + def __setitem__(self, index, value): + value = np.array(value).astype(self.dtype) + + with self.test.test_session() as sess, self.test.test_scope(): + x = constant_op.constant(self.x_np, dtype=self.dtype) + var = resource_variable_ops.ResourceVariable(x) + sess.run(variables.variables_initializer([var])) + val = sess.run(var[index].assign(value)) + # val_copy is used to check that tf.assign works equivalently to the + # assign method above. 
+ val_copy = sess.run(state_ops.assign(var[index], value)) + valnp = np.copy(self.x_np) + valnp[index] = np.array(value) + self.test.assertAllEqual(val, valnp) + self.test.assertAllEqual(val_copy, valnp) + + +class SliceAssignTest(XLATestCase): + + def testSliceAssign(self): + for dtype in self.numeric_types: + checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], + dtype=dtype) + # No-op assignment + checker[:] = [[10, 20, 30], [40, 50, 60]] + # Checks trivial (1,1) shape tensor + checker[1:2, 1:2] = [[66]] + # shrink shape changes + checker[1:2, 1] = [66] + checker[1, 1:2] = [66] + checker[1, 1] = 66 + # newaxis shape changes + checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] + # shrink and newaxis + checker[None, None, 0, 0:1] = [[[99]]] + # Non unit strides + checker[::1, 1::-1] = [[3, 33], [4, 44]] + # degenerate interval + checker[8:10, 0] = [] + checker[8:10, 8:10] = [[]] + + # Assign vector to scalar (rank-0) using newaxis + checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis + checker2[None] = [6] # new axis + + def testUninitialized(self): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "uninitialized variable"): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable([1, 2]) + sess.run(v[:].assign([1, 2])) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 1388a892ba5..f5c228f8305 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -18,15 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -48,34 +43,6 @@ class XlaDeviceTest(test.TestCase): result = sess.run(w, {x: [1.5, 0.5]}) self.assertAllClose(result, [12., 2.], rtol=1e-3) - def testLoops(self): - """Tests that loops work on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - with ops.device("device:XLA_CPU:0"): - c = lambda i, _: math_ops.less(i, 5) - b = lambda i, x: (i + 1, x * 2.0 + 1.0) - _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x)) - - result = session.run(y, {x: np.float32(2)}) - self.assertAllClose(result, np.float32(95), rtol=1e-3) - - def testCond(self): - """Tests that tf.cond works on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - y = array_ops.placeholder(dtypes.float32) - c = array_ops.placeholder(dtypes.bool) - with ops.device("device:XLA_CPU:0"): - z = x + 1.0 - w = control_flow_ops.cond(c, lambda: z, lambda: y) - t = math_ops.add(z, w) - - result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True}) - self.assertAllClose(result, np.float32(6), rtol=1e-3) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index b72e7c9713d..79549644ea0 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ 
b/tensorflow/compiler/tests/xla_test.py @@ -19,14 +19,18 @@ from __future__ import division from __future__ import print_function import contextlib +import random import re +import numpy as np + from tensorflow.contrib.compiler import jit from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import flags @@ -50,16 +54,20 @@ class XLATestCase(test.TestCase): self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') self.all_tf_types = [ - dtypes.DType(types_pb2.DataType.Value(name)) + dtypes.as_dtype(types_pb2.DataType.Value(name)) for name in FLAGS.types.split(',') ] + self.int_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_integer + ] + self.float_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_floating + ] + self.numeric_tf_types = self.int_tf_types + self.float_tf_types + self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] - self.int_types = [ - dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer - ] - self.float_types = [ - dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating - ] + self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] + self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] self.numeric_types = self.int_types + self.float_types # Parse the manifest file, if any, into a regex identifying tests to @@ -81,6 +89,9 @@ class XLATestCase(test.TestCase): return logging.info('Start test case: %s', name) + random.seed(random_seed.DEFAULT_GRAPH_SEED) + np.random.seed(random_seed.DEFAULT_GRAPH_SEED) + def tearDown(self): logging.info('End test case: %s', self._testMethodName) @@ -112,7 +123,11 @@ class XLATestCase(test.TestCase): yield -def Benchmark(tf_bench, builder_fn, use_xla_jit, device): +def Benchmark(tf_bench, + builder_fn, + use_xla_jit, + device, + separate_compiled_gradients=False): """Build a graph and run benchmarks against it, with or without XLA. Args: @@ -122,6 +137,14 @@ def Benchmark(tf_bench, builder_fn, use_xla_jit, device): is a list of tensors to fetch as output. use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF. device: The tensorflow device to run on, e.g. "cpu", "gpu". + separate_compiled_gradients: If true put each gradient subgraph into a + separate compilation scope. This gives fine-grained control over which + portions of the graph will be compiled as a single unit. Compiling + gradients separately may yield better performance for some graphs. + The scope is named based on the scope of the forward computation as well + as the name of the gradients. As a result, the gradients will be compiled + in a scope that is separate from both the forward computation, and from + other gradients. 
""" with ops.Graph().as_default(): @@ -130,7 +153,9 @@ def Benchmark(tf_bench, builder_fn, use_xla_jit, device): with ops.device(device): fetches = [] jit_scope = jit.experimental_jit_scope - with jit_scope(compile_ops=use_xla_jit): + with jit_scope( + compile_ops=use_xla_jit, + separate_compiled_gradients=separate_compiled_gradients): name, fetches = builder_fn() # We only want to benchmark the operations themselves, and not the data diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 08a03b8d357..93c484ca7a0 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -23,12 +23,12 @@ package( cc_library( name = "xla_compiler", srcs = [ - "op_registrations.cc", "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", "xla_helpers.cc", "xla_op_kernel.cc", + "xla_op_registry.cc", ], hdrs = [ "xla_compilation_device.h", @@ -36,18 +36,21 @@ cc_library( "xla_context.h", "xla_helpers.h", "xla_op_kernel.h", + "xla_op_registry.h", ], + visibility = [":friends"], deps = [ ":common", ":dump_graph", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -89,6 +92,7 @@ cc_library( cc_test( name = "xla_compiler_test", + size = "small", srcs = ["xla_compiler_test.cc"], deps = [ ":xla_compiler", @@ -110,6 +114,7 @@ cc_test( cc_test( name = "str_util_test", + size = "small", srcs = [ "str_util_test.cc", ], @@ -123,6 +128,7 @@ cc_test( cc_test( name = "literal_util_test", + size = "small", srcs = [ "literal_util_test.cc", ], diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index e072ef7be7e..36a6c90af4f 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g, {"Any", "reduction_indices"}, {"ArgMax", "dimension"}, {"AvgPoolGrad", "orig_input_shape"}, + {"BatchToSpace", "crops"}, + {"BatchToSpaceND", "block_shape"}, + {"BatchToSpaceND", "crops"}, {"BroadcastGradientArgs", "s0"}, {"BroadcastGradientArgs", "s1"}, {"Concat", "concat_dim"}, @@ -43,6 +46,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"ConcatOffset", "shape"}, {"Conv2DBackpropFilter", "filter_sizes"}, {"Conv2DBackpropInput", "input_sizes"}, + {"Conv3DBackpropFilterV2", "filter_sizes"}, + {"Conv3DBackpropInputV2", "input_sizes"}, {"DynamicStitch", "indices"}, {"ExpandDims", "dim"}, {"Fill", "dims"}, @@ -53,6 +58,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Max", "reduction_indices"}, {"Mean", "reduction_indices"}, {"Min", "reduction_indices"}, + {"OneHot", "depth"}, {"Pad", "paddings"}, {"Prod", "reduction_indices"}, {"RandomStandardNormal", "shape"}, @@ -62,8 +68,16 @@ Status BackwardsConstAnalysis(const Graph& g, {"Range", "limit"}, {"Range", "delta"}, {"Reshape", "shape"}, + {"ResourceStridedSliceAssign", "begin"}, + {"ResourceStridedSliceAssign", "end"}, + {"ResourceStridedSliceAssign", "strides"}, + {"Reverse", "dims"}, + {"ReverseV2", "axis"}, {"Slice", 
"begin"}, {"Slice", "size"}, + {"SpaceToBatch", "paddings"}, + {"SpaceToBatchND", "block_shape"}, + {"SpaceToBatchND", "paddings"}, {"Split", "split_dim"}, {"SplitV", "split_dim"}, {"SplitV", "size_splits"}, @@ -75,6 +89,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"StridedSliceGrad", "end"}, {"StridedSliceGrad", "strides"}, {"Sum", "reduction_indices"}, + {"TensorArrayV3", "size"}, + {"TensorArraySplitV3", "lengths"}, {"Tile", "multiples"}, {"Transpose", "perm"}}; @@ -97,7 +113,7 @@ Status BackwardsConstAnalysis(const Graph& g, if (must_be_const.find(node) != must_be_const.end()) { if (node->type_string() == "_Arg") { int index; - status = GetNodeAttr(node->def(), "index", &index); + status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; compile_time_const_args->at(index) = true; return; @@ -113,8 +129,8 @@ Status BackwardsConstAnalysis(const Graph& g, if (range.first == range.second) return; NameRangeMap input_name_ranges; - status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges, - nullptr); + status = + NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; for (auto it = range.first; it != range.second; ++it) { diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 5aa6f806ac6..af5753c2600 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,8 +33,16 @@ struct NameCounts { std::unordered_map counts; }; -string MakeUniquePath(const string& name) { +string MakeUniquePath(string name) { static NameCounts& instance = *new NameCounts; + + // Remove illegal characters from `name`. + for (int i = 0; i < name.size(); ++i) { + if (name[i] == '/') { + name[i] = '_'; + } + } + int count; { mutex_lock lock(instance.counts_mutex); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d913f898e94..a434c746809 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -14,18 +14,21 @@ tf_kernel_library( name = "xla_ops", srcs = [ "aggregate_ops.cc", + "arg_op.cc", "batch_matmul_op.cc", + "batchtospace_op.cc", "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", "cast_op.cc", "concat_op.cc", + "const_op.cc", "conv_ops.cc", "cwise_ops.cc", - "declaration_op.cc", "depthwise_conv_ops.cc", "diag_op.cc", "dynamic_stitch_op.cc", + "elu_op.cc", "fill_op.cc", "function_ops.cc", "identity_op.cc", @@ -33,6 +36,7 @@ tf_kernel_library( "lrn_ops.cc", "matmul_op.cc", "no_op.cc", + "one_hot_op.cc", "pack_op.cc", "pad_op.cc", "pooling_ops.cc", @@ -42,17 +46,22 @@ tf_kernel_library( "relu_op.cc", "reshape_op.cc", "retval_op.cc", + "reverse_op.cc", "select_op.cc", "sequence_ops.cc", "shape_op.cc", "slice_op.cc", "softmax_op.cc", + "spacetobatch_op.cc", "split_op.cc", "strided_slice_op.cc", + "tensor_array_ops.cc", "tile_ops.cc", + "training_ops.cc", "transpose_op.cc", "unary_ops.cc", "unpack_op.cc", + "variable_ops.cc", ], hdrs = [ "cwise_ops.h", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 8f284c30174..5c9f66df101 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" namespace tensorflow { namespace { @@ -41,7 +41,7 @@ class AddNOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(AddNOp); }; -REGISTER_XLA_OP("AddN", AddNOp); +REGISTER_XLA_OP(Name("AddN"), AddNOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc new file mode 100644 index 00000000000..620fc844378 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +// This OpKernel implements the _Arg Op for XLA JIT devices. It +// associates its output with one of the arguments to a +// subcomputation. +class ArgOp : public XlaOpKernel { + public: + explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // If 'frame' is non-null, this is a function call inside an outer JIT + // compilation. Use the usual implementation of _Arg. + auto frame = ctx->call_frame(); + if (frame != nullptr) { + Tensor val; + OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + // Forwards the argument from the frame. + ctx->op_kernel_context()->set_output(0, val); + return; + } + + XlaContext& xc = XlaContext::Get(ctx); + const XlaContext::Argument& arg = xc.args()[index_]; + if (arg.is_variable) { + // TODO(phawkins): this code assumes that variables do not alias. 
+ XlaVariable* var; + OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type, + arg.value.handle, &var)); + var->tensor_array_size = arg.tensor_array_size; + ctx->SetVariableOutput(0, var); + } else if (arg.value.is_constant) { + ctx->SetConstantOutput(0, arg.value.constant_value); + } else { + ctx->SetOutput(0, arg.value.handle); + } + } + + private: + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), ArgOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 637360d149e..16b778bca43 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -18,8 +18,8 @@ limitations under the License. // dimension. // TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling. -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" namespace tensorflow { namespace { @@ -94,12 +94,14 @@ class BatchMatMulOp : public XlaOpKernel { // Slice off individual matrices and reshape to 2D tensors. auto x_slice = builder->Slice( x_flat, {i, 0, 0}, - {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); + {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}, + {1, 1, 1}); x_slice = builder->Reshape( x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); auto y_slice = builder->Slice( y_flat, {i, 0, 0}, - {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); + {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}, + {1, 1, 1}); y_slice = builder->Reshape( y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); @@ -135,7 +137,7 @@ class BatchMatMulOp : public XlaOpKernel { bool adj_y_; }; -REGISTER_XLA_OP("BatchMatMul", BatchMatMulOp); +REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc new file mode 100644 index 00000000000..8642cbf2a92 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -0,0 +1,187 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void BatchToSpace(XlaOpKernelContext* ctx,
+                  const xla::ComputationDataHandle& input, DataType input_dtype,
+                  const TensorShape& input_tensor_shape,
+                  gtl::ArraySlice<int64> block_shape,
+                  const xla::Literal& crops) {
+  const int input_rank = input_tensor_shape.dims();
+  const gtl::InlinedVector<int64, 4> input_shape =
+      input_tensor_shape.dim_sizes();
+  const int block_rank = block_shape.size();
+
+  OP_REQUIRES(
+      ctx, input_rank >= 1 + block_rank,
+      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+                              " instead of ", input_rank));
+  gtl::ArraySlice<int64> remainder_shape(input_shape);
+  remainder_shape.remove_prefix(1 + block_rank);
+
+  OP_REQUIRES(
+      ctx,
+      xla::ShapeUtil::Rank(crops.shape()) == 2 &&
+          block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
+          2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
+      errors::InvalidArgument("crops should have shape [", block_rank,
+                              ", 2] instead of ",
+                              xla::ShapeUtil::HumanString(crops.shape())));
+
+  xla::ComputationBuilder* b = ctx->builder();
+  const int64 batch_size = input_shape[0];
+
+  // Compute the product of the block_shape values.
+  int64 block_num_elems = 1;
+  for (int i = 0; i < block_rank; ++i) {
+    block_num_elems *= block_shape[i];
+  }
+  OP_REQUIRES(ctx, block_num_elems > 0,
+              errors::InvalidArgument(
+                  "The product of the block dimensions must be positive"));
+
+  // 1. Reshape `input` to `reshaped` of shape:
+  //      [block_shape[0], ..., block_shape[M-1],
+  //       batch / prod(block_shape),
+  //       input_shape[1], ..., input_shape[N-1]]
+
+  OP_REQUIRES(
+      ctx, batch_size % block_num_elems == 0,
+      errors::InvalidArgument("Input batch dimension (", batch_size,
+                              ") is not divisible by product of block sizes (",
+                              block_num_elems, ")"));
+  std::vector<int64> reshaped_shape(input_rank + block_rank);
+  std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
+  reshaped_shape[block_rank] = batch_size / block_num_elems;
+  std::copy(input_shape.begin() + 1, input_shape.end(),
+            reshaped_shape.begin() + block_rank + 1);
+  xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+
+  // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1], block_shape[0],
+  //       ...,
+  //       input_shape[M], block_shape[M-1],
+  //
+  //       input_shape[M+1], ..., input_shape[N-1]]
+  std::vector<int64> permutation(reshaped_shape.size());
+  permutation[0] = block_rank;
+  for (int i = 0; i < block_rank; ++i) {
+    permutation[1 + 2 * i] = block_rank + 1 + i;
+    permutation[1 + 2 * i + 1] = i;
+  }
+  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+            1 + block_rank * 2);
+  xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+
+  // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1] * block_shape[0],
+  //       ...,
+  //       input_shape[M] * block_shape[M-1],
+  //
+  //       input_shape[M+1],
+  //       ...,
+  //       input_shape[N-1]]
+  std::vector<int64> reshaped_permuted_shape(input_rank);
+  reshaped_permuted_shape[0] = batch_size / block_num_elems;
+  for (int i = 0; i < block_rank; ++i) {
+    reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
+  }
+  std::copy(remainder_shape.begin(), remainder_shape.end(),
+            reshaped_permuted_shape.begin() + 1 + block_rank);
+
+  xla::ComputationDataHandle reshaped_permuted =
+      b->Reshape(permuted, reshaped_permuted_shape);
+
+  // 4. Crop the start and end of dimensions `[1, ..., M]` of
+  //    `reshaped_permuted` according to `crops` to produce the output of
+  //    shape:
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
+  //       ...,
+  //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
+  //
+  //       input_shape[M+1], ..., input_shape[N-1]]
+  std::vector<int64> start_indices(input_rank, 0);
+  std::vector<int64> end_indices = reshaped_permuted_shape;
+  std::vector<int64> strides(input_rank, 1);
+  for (int i = 0; i < block_rank; ++i) {
+    int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
+    int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
+    OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
+                errors::InvalidArgument("Crops must be non-negative"));
+    start_indices[1 + i] = crop_start;
+    end_indices[1 + i] -= crop_end;
+    OP_REQUIRES(
+        ctx, start_indices[1 + i] <= end_indices[1 + i],
+        errors::InvalidArgument(
+            "Cropped size must be non-negative: start: ", crop_start,
+            " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
+  }
+  xla::ComputationDataHandle output =
+      b->Slice(reshaped_permuted, start_indices, end_indices, strides);
+  ctx->SetOutput(0, output);
+}
+
+class BatchToSpaceNDOp : public XlaOpKernel {
+ public:
+  explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    std::vector<int64> block_shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+    xla::Literal crops;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
+
+    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 block_shape, crops);
+  }
+};
+REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+
+class BatchToSpaceOp : public XlaOpKernel {
+ public:
+  explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+    OP_REQUIRES(
+        ctx, block_size_ > 1,
+        errors::InvalidArgument("Block size should be > 1: ", block_size_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::Literal crops;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
+
+    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 {block_size_, block_size_}, crops);
+  }
+
+ private:
+  int block_size_;
+};
+REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index f35835df087..b0fee5e4bca 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -16,9 +16,9 @@ limitations under the License.
 // XLA-specific Ops for broadcasting used in gradient
 // code.
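The four numbered steps in `BatchToSpace` in `batchtospace_op.cc` above map one-to-one onto NumPy reshape/transpose/slice operations. The following sketch mirrors them for intuition (illustrative only, not the XLA implementation; names are ad hoc):

```python
import numpy as np

def batch_to_space_nd(x, block_shape, crops):
    """Mirrors steps 1-4 of BatchToSpace in batchtospace_op.cc."""
    m = len(block_shape)
    prod_block = int(np.prod(block_shape))
    # 1. Split the batch dimension into the block dimensions.
    reshaped = x.reshape(
        list(block_shape) + [x.shape[0] // prod_block] + list(x.shape[1:]))
    # 2. Interleave each block dimension after its spatial dimension.
    perm = [m]
    for i in range(m):
        perm += [m + 1 + i, i]
    perm += list(range(2 * m + 1, reshaped.ndim))
    permuted = reshaped.transpose(perm)
    # 3. Merge each (spatial, block) pair into a single dimension.
    merged = permuted.reshape(
        [x.shape[0] // prod_block]
        + [x.shape[1 + i] * block_shape[i] for i in range(m)]
        + list(x.shape[1 + m:]))
    # 4. Crop the start and end of each upscaled spatial dimension.
    index = [slice(None)] + [
        slice(crops[i][0], merged.shape[1 + i] - crops[i][1]) for i in range(m)
    ]
    return merged[tuple(index)]
```

For example, `batch_to_space_nd(np.arange(4.0).reshape(4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]])` reassembles the four batch entries into a single 2x2 spatial grid of shape `[1, 2, 2, 1]`.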
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -81,7 +81,7 @@ class BCastGradArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); }; -REGISTER_XLA_OP("BroadcastGradientArgs", BCastGradArgsOp); +REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 217e82304e3..c667b4e3e32 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" @@ -69,8 +69,8 @@ class BiasOp : public XlaOpKernel { TensorFormat data_format_; }; -REGISTER_XLA_OP("BiasAdd", BiasOp); -REGISTER_XLA_OP("BiasAddV1", BiasOp); +REGISTER_XLA_OP(Name("BiasAdd"), BiasOp); +REGISTER_XLA_OP(Name("BiasAddV1"), BiasOp); class BiasAddGradOp : public XlaOpKernel { public: @@ -113,7 +113,7 @@ class BiasAddGradOp : public XlaOpKernel { TensorFormat data_format_; }; -REGISTER_XLA_OP("BiasAddGrad", BiasAddGradOp); +REGISTER_XLA_OP(Name("BiasAddGrad"), BiasAddGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 6f117ebe616..ded20a9a3ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -16,8 +16,8 @@ limitations under the License. // Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -28,10 +28,10 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. 
-#define XLA_MAKE_BINARY(Name, HLO) \ - class Name##Op : public XlaBinaryOp { \ +#define XLA_MAKE_BINARY(NAME, HLO) \ + class NAME##Op : public XlaBinaryOp { \ public: \ - explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ xla::ComputationDataHandle Computation( \ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \ const gtl::ArraySlice& lhs_shape, \ @@ -43,7 +43,7 @@ namespace { return HLO; \ } \ }; \ - REGISTER_XLA_OP(#Name, Name##Op) + REGISTER_XLA_OP(Name(#NAME), NAME##Op) XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); @@ -127,32 +127,21 @@ XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); +// Non-linear ops +XLA_MAKE_BINARY(SigmoidGrad, + b->Mul(b->Mul(rhs, lhs), + b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); + +XLA_MAKE_BINARY(SoftplusGrad, + b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); + +XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(lhs, lhs)))); + +XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); + #undef XLA_MAKE_BINARY -#define XLA_MAKE_BINARY_MAP(Name, HLO) \ - class Name##Op : public XlaBinaryMapOp { \ - public: \ - explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \ - void BuildMapLambda(xla::ComputationBuilder* b, \ - const xla::ComputationDataHandle& lhs, \ - const xla::ComputationDataHandle& rhs) override { \ - HLO; \ - } \ - }; \ - REGISTER_XLA_OP(#Name, Name##Op) - -XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs)); -XLA_MAKE_BINARY_MAP(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); -XLA_MAKE_BINARY_MAP(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); -XLA_MAKE_BINARY_MAP(TanhGrad, - b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); - -#undef XLA_MAKE_BINARY_MAP - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index b0188b4f8d8..124e33d7935 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -65,7 +65,7 @@ class CastOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(CastOp); }; -REGISTER_XLA_OP("Cast", CastOp); +REGISTER_XLA_OP(Name("Cast"), CastOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index d086e55cb79..e2eacb3839d 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -19,9 +19,9 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -117,8 +117,8 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP("Concat", ConcatOp); -REGISTER_XLA_OP("ConcatV2", ConcatV2Op); +REGISTER_XLA_OP(Name("Concat"), ConcatOp); +REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { public: @@ -204,7 +204,7 @@ class ConcatOffsetOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("ConcatOffset", ConcatOffsetOp); +REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc new file mode 100644 index 00000000000..ad676e7a2bb --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class ConstOp : public XlaOpKernel { + public: + explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const TensorProto* proto = nullptr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); + proto_ = *proto; + OP_REQUIRES( + ctx, ctx->output_type(0) == proto_.dtype(), + errors::InvalidArgument("Type mismatch between value (", + DataTypeString(proto_.dtype()), ") and dtype (", + DataTypeString(ctx->output_type(0)), ")")); + OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape())); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape(proto_.tensor_shape()); + + xla::ComputationBuilder* b = ctx->builder(); + + // To avoid blowups for large constants filled with the same value, + // recognize that case and emit a scalar broadcast instead. 
+    if (shape.num_elements() > 1) {
+      switch (proto_.dtype()) {
+        case DT_BOOL:
+          if (proto_.bool_val_size() == 1) {
+            ctx->SetOutput(0,
+                           b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)),
+                                        shape.dim_sizes()));
+            return;
+          }
+          break;
+        case DT_FLOAT:
+          if (proto_.float_val_size() == 1) {
+            ctx->SetOutput(
+                0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)),
+                                shape.dim_sizes()));
+            return;
+          }
+          break;
+        case DT_DOUBLE:
+          if (proto_.double_val_size() == 1) {
+            ctx->SetOutput(
+                0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)),
+                                shape.dim_sizes()));
+            return;
+          }
+          break;
+        case DT_INT32:
+          if (proto_.int_val_size() == 1) {
+            ctx->SetOutput(0,
+                           b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)),
+                                        shape.dim_sizes()));
+            return;
+          }
+          break;
+        case DT_INT64:
+          if (proto_.int64_val_size() == 1) {
+            ctx->SetOutput(
+                0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)),
+                                shape.dim_sizes()));
+            return;
+          }
+          break;
+        default:
+          break;
+      }
+    }
+
+    // General case
+    Tensor tensor(proto_.dtype());
+    OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
+                errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                        proto_.DebugString()));
+    ctx->SetConstantOutput(0, tensor);
+  }
+
+ private:
+  TensorProto proto_;
+  TF_DISALLOW_COPY_AND_ASSIGN(ConstOp);
+};
+
+// XLA_* devices also register a "real" Const operator so we suppress the
+// dummy operator using CompilationOnly().
+REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 9bebfcfe47d..67a0b803c5b 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,9 +15,9 @@ limitations under the License.
 // XLA-specific Ops for 2D convolution.
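The splat shortcut in `ConstOp::Compile` above works because `TensorProto` stores a tensor filled with a single repeated value as just one element of its value field, so the kernel can recognize that case cheaply and emit a scalar constant plus a `Broadcast` instead of a huge literal. A quick illustration (a sketch assuming the internal helper `tensor_util.make_tensor_proto`):

```python
import numpy as np
from tensorflow.python.framework import tensor_util

# A scalar broadcast to a large shape is stored as one repeated value.
proto = tensor_util.make_tensor_proto(0.5, shape=[1000, 1000])
print(len(proto.float_val))  # 1 -- a "splat", not 10**6 floats

# ConstOp then emits ConstantR0(0.5) + Broadcast; the NumPy analogue:
splat = np.broadcast_to(np.float32(0.5), (1000, 1000))
```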
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,96 +35,67 @@ namespace tensorflow { namespace { -class Conv2DOp : public XlaOpKernel { +class ConvOp : public XlaOpKernel { public: - explicit Conv2DOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); - const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); - OP_REQUIRES( - ctx, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); } + int num_dims() const { return num_spatial_dims_ + 2; } + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES(ctx, strides_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); + int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); + OP_REQUIRES( + ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, in_depth, out_depth] + // [ filter_rows, filter_cols, ..., in_depth, out_depth] const TensorShape filter_shape = ctx->InputShape(1); // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(ctx, input_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_shape.DebugString())); - OP_REQUIRES(ctx, filter_shape.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter_shape.DebugString())); + OP_REQUIRES( + ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("input must be ", num_dims(), "-dimensional", + input_shape.DebugString())); + OP_REQUIRES( + ctx, filter_shape.dims() == num_dims(), + errors::InvalidArgument("filter must be ", num_dims(), + "-dimensional: ", filter_shape.DebugString())); + + // The last two dimension of the filter are the input and output shapes. + const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); // The 'C' dimension for input is in_depth. It must be the same as // the filter's in_depth. - const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES( - ctx, in_depth == filter_shape.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter_shape.dim_size(2))); - - // The last dimension for filter is out_depth. 
- const int64 out_depth = filter_shape.dim_size(3); - - // The 'H' dimension for input is rows/height. - // The first dimension for filter is rows/height. - const int64 input_rows = GetTensorDim(input_shape, data_format_, 'H'); - const int64 filter_rows = filter_shape.dim_size(0); - - // The 'W' dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols = GetTensorDim(input_shape, data_format_, 'W'); - const int64 filter_cols = filter_shape.dim_size(1); - - // For now we take the stride from the H and W dimensions only (we - // do not support striding on the batch or depth dimension). - const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(ctx, - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(ctx, - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_cols)); - - VLOG(2) << "Conv2D: in_depth = " << in_depth - << ", input_cols = " << input_cols - << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows - << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", out_depth = " << out_depth; + OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", input_shape.dim_size(feature_dim))); xla::ConvolutionDimensionNumbers dims; - dims.set_batch_dimension(GetTensorDimIndex<2>(data_format_, 'N')); - dims.set_feature_dimension(GetTensorDimIndex<2>(data_format_, 'C')); - dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'H')); - dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'W')); + std::vector window_strides; - // TF filter shape is [ H, W, inC, outC ] - dims.add_kernel_spatial_dimensions(0); - dims.add_kernel_spatial_dimensions(1); - dims.set_kernel_input_feature_dimension(2); - dims.set_kernel_output_feature_dimension(3); + dims.set_batch_dimension(GetTensorBatchDimIndex(num_dims(), data_format_)); + dims.set_feature_dimension(feature_dim); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dims.add_spatial_dimensions(input_dim); + dims.add_kernel_spatial_dimensions(i); + window_strides.push_back(strides_.at(input_dim)); + } + dims.set_kernel_input_feature_dimension(num_spatial_dims_); + dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - std::vector window_strides = {stride_rows, stride_cols}; xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; @@ -133,38 +104,58 @@ class Conv2DOp : public XlaOpKernel { ctx->SetOutput(0, conv); } - private: + protected: + const int num_spatial_dims_; std::vector strides_; Padding padding_; - TensorFormat data_format_; + TensorFormat data_format_ = FORMAT_NHWC; - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); + private: + TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); }; -REGISTER_XLA_OP("Conv2D", Conv2DOp); - -// Backprop for input. 
-class Conv2DBackpropInputOp : public XlaOpKernel { +class Conv2DOp : public ConvOp { public: - explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit Conv2DOp(OpKernelConstruction* ctx) + : ConvOp(ctx, /*num_spatial_dims=*/2) { string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp); + +class Conv3DOp : public ConvOp { + public: + explicit Conv3DOp(OpKernelConstruction* ctx) + : ConvOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp); + +// Backprop for input. +class ConvBackpropInputOp : public XlaOpKernel { + public: + explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES(ctx, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - int stride_n = GetTensorDim(strides_, data_format_, 'N'); - int stride_c = GetTensorDim(strides_, data_format_, 'C'); - OP_REQUIRES( - ctx, (stride_n == 1 && stride_c == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); } + int num_dims() const { return num_spatial_dims_ + 2; } + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES(ctx, strides_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); + int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); + OP_REQUIRES( + ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -172,10 +163,10 @@ class Conv2DBackpropInputOp : public XlaOpKernel { const TensorShape out_backprop_shape = ctx->InputShape(2); // Reuse dimension computation logic from conv_grad_ops.cc. - Conv2DBackpropDimensions dims; + ConvBackpropDimensions dims; OP_REQUIRES_OK( - ctx, Conv2DBackpropComputeDimensions( - "Conv2DBackpropInput", input_shape, filter_shape, + ctx, ConvBackpropComputeDimensions( + type_string(), num_spatial_dims_, input_shape, filter_shape, out_backprop_shape, strides_, padding_, data_format_, &dims)); auto filter = ctx->Input(1); @@ -186,73 +177,101 @@ class Conv2DBackpropInputOp : public XlaOpKernel { // comment at the top of conv_grad_ops.h for details. xla::ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(GetTensorDimIndex(data_format_, 'N')); - dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'H')); - dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'W')); - dnums.set_feature_dimension(GetTensorDimIndex(data_format_, 'C')); + dnums.set_batch_dimension(batch_dim); + dnums.set_feature_dimension(feature_dim); - // TF filter shape is [ H, W, inC, outC ] + // TF filter shape is [ H, W, ..., inC, outC ] // Transpose the input and output features for computing the gradient. 
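The `Rev` + `ConvGeneralDilated` combination just below implements the standard identity that the gradient with respect to a convolution's input is the output gradient, padded (and dilated via `lhs_dilation` when strides are involved), correlated with the spatially mirrored filter. A 1-D, stride-1 NumPy sketch of that identity (illustrative only):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])  # input
w = np.array([0.5, -1.0, 2.0])           # filter
m = len(w)
dy = np.ones(len(x) - m + 1)             # upstream grad of y = correlate(x, w)

# Direct gradient: dL/dx accumulates dy[j] * w over each window position j.
dx = np.zeros_like(x)
for j in range(len(dy)):
    dx[j:j + m] += dy[j] * w

# Same result: pad dy by (m - 1) and correlate with the mirrored filter.
dx2 = np.correlate(np.pad(dy, m - 1, mode="constant"), w[::-1], mode="valid")
assert np.allclose(dx, dx2)
```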
-    dnums.add_kernel_spatial_dimensions(0);
-    dnums.add_kernel_spatial_dimensions(1);
-    dnums.set_kernel_input_feature_dimension(3);
-    dnums.set_kernel_output_feature_dimension(2);
+    dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
+    dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
+
+    std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
+    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
+    std::vector<int64> lhs_dilation(num_spatial_dims_);
+    std::vector<int64> ones(num_spatial_dims_, 1);
+    for (int i = 0; i < num_spatial_dims_; ++i) {
+      dnums.add_spatial_dimensions(
+          GetTensorSpatialDimIndex(num_dims(), data_format_, i));
+      dnums.add_kernel_spatial_dimensions(i);
+
+      kernel_spatial_dims[i] = i;
+      padding[i] = {dims.spatial_dims[i].pad_before,
+                    dims.spatial_dims[i].pad_after};
+      lhs_dilation[i] = dims.spatial_dims[i].stride;
+    }
 
     // Mirror the filter in the spatial dimensions.
     xla::ComputationDataHandle mirrored_weights =
-        ctx->builder()->Rev(filter, {dnums.kernel_spatial_dimensions(0),
-                                     dnums.kernel_spatial_dimensions(1)});
+        ctx->builder()->Rev(filter, kernel_spatial_dims);
 
     // activation gradients
     //   = gradients (with padding and dilation) <conv> mirrored_weights
     xla::ComputationDataHandle in_backprop =
         ctx->builder()->ConvGeneralDilated(
-            out_backprop, mirrored_weights, /*window_strides=*/{1, 1},
-            /*padding=*/{{dims.rows.pad_before, dims.rows.pad_after},
-                         {dims.cols.pad_before, dims.cols.pad_after}},
-            /*lhs_dilation=*/{dims.rows.stride, dims.cols.stride},
-            /*rhs_dilation=*/{1, 1}, dnums);
+            out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+            lhs_dilation, /*rhs_dilation=*/ones, dnums);
 
     ctx->SetOutput(0, in_backprop);
   }
 
- private:
+ protected:
+  const int num_spatial_dims_;
   std::vector<int32> strides_;
   Padding padding_;
-  TensorFormat data_format_;
+  TensorFormat data_format_ = FORMAT_NHWC;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp);
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
 };
 
-class Conv2DBackpropFilterOp : public XlaOpKernel {
+class Conv2DBackpropInputOp : public ConvBackpropInputOp {
  public:
-  explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
-      : XlaOpKernel(ctx) {
+  explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
+      : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2) {
     string data_format;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
+  }
+};
+REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp);
+
+class Conv3DBackpropInputOp : public ConvBackpropInputOp {
+ public:
+  explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
+      : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3) {}
+};
+REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp);
+
+class ConvBackpropFilterOp : public XlaOpKernel {
+ public:
+  explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims)
+      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
-    int stride_n = GetTensorDim(strides_, data_format_, 'N');
-    int stride_c = GetTensorDim(strides_, data_format_, 'C');
-    OP_REQUIRES(
-        ctx, (stride_n == 1 && stride_c == 1),
-        errors::InvalidArgument("Current implementation does not yet support "
-                                "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
   }
 
+  int num_dims() const { return num_spatial_dims_ + 2; }
+
   void Compile(XlaOpKernelContext* ctx) override {
+    const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
+    const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
+
+    OP_REQUIRES(
+        ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+
     const TensorShape activations_shape = ctx->InputShape(0);
     TensorShape filter_shape;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
     const TensorShape out_backprop_shape = ctx->InputShape(2);
 
     // Reuse dimension computation logic from conv_grad_ops.cc.
-    Conv2DBackpropDimensions dims;
-    OP_REQUIRES_OK(
-        ctx, Conv2DBackpropComputeDimensions(
-                 "Conv2DBackpropFilter", activations_shape, filter_shape,
-                 out_backprop_shape, strides_, padding_, data_format_, &dims));
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+                            type_string(), num_spatial_dims_, activations_shape,
+                            filter_shape, out_backprop_shape, strides_,
+                            padding_, data_format_, &dims));
 
     xla::ComputationDataHandle activations = ctx->Input(0);
     xla::ComputationDataHandle gradients = ctx->Input(2);
 
@@ -264,72 +283,71 @@ class Conv2DBackpropFilterOp : public XlaOpKernel {
     xla::ConvolutionDimensionNumbers dnums;
 
     // The activations (inputs) form the LHS of the convolution.
-    // Activations have shape: [batch, in_rows, in_cols, in_depth]
+    // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
     // For the gradient computation, we flip the roles of the batch and
     // feature dimensions.
     // Each spatial entry has size in_depth * batch
-    const int n_dim = GetTensorDimIndex(data_format_, 'N');
-    const int h_dim = GetTensorDimIndex(data_format_, 'H');
-    const int w_dim = GetTensorDimIndex(data_format_, 'W');
-    const int c_dim = GetTensorDimIndex(data_format_, 'C');
 
     // Swap n_dim and c_dim in the activations.
     dnums.set_batch_dimension(c_dim);
-    dnums.add_spatial_dimensions(h_dim);
-    dnums.add_spatial_dimensions(w_dim);
     dnums.set_feature_dimension(n_dim);
 
     // The gradients become the RHS of the convolution.
-    // The gradients have shape [batch, out_rows, out_cols, out_depth] where
-    // the batch becomes the input feature for the convolution.
-    dnums.add_kernel_spatial_dimensions(h_dim);
-    dnums.add_kernel_spatial_dimensions(w_dim);
+    // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+    // where the batch becomes the input feature for the convolution.
     dnums.set_kernel_input_feature_dimension(n_dim);
     dnums.set_kernel_output_feature_dimension(c_dim);
 
-    // We will also need to pad the input with zeros such that after the
-    // convolution, we get the right size for the filter.
-    // The padded_in_rows should be such that when we convolve this with the
-    // expanded_out_rows as a filter, we should get filter_rows back.
-    //
-    const int padded_in_rows =
-        dims.rows.expanded_output_size + dims.rows.filter_size - 1;
-    const int padded_in_cols =
-        dims.cols.expanded_output_size + dims.cols.filter_size - 1;
+    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
+    std::vector<int64> rhs_dilation(num_spatial_dims_);
+    std::vector<int64> ones(num_spatial_dims_, 1);
 
-    // However it can be smaller than input_rows: in this
-    // case it means some of the inputs are not used.
-    //
-    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
-    //
-    // INPUT =  [ A  B  C ]
-    //
-    // FILTER = [ x y ]
-    //
-    // and the output will only have one column: a = A * x + B * y
-    //
-    // and input "C" is not used at all.
-    //
-    // We apply negative padding in this case.
- const int total_pad_in_rows = padded_in_rows - dims.rows.input_size; - const int total_pad_in_cols = padded_in_cols - dims.cols.input_size; + for (int i = 0; i < num_spatial_dims_; ++i) { + int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int top_pad_in_rows = - (total_pad_in_rows > 0 && padding_ == Padding::SAME) - ? total_pad_in_rows / 2 - : 0; - const int left_pad_in_cols = - (total_pad_in_cols > 0 && padding_ == Padding::SAME) - ? total_pad_in_cols / 2 - : 0; + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int padded_in_size = dims.spatial_dims[i].expanded_output_size + + dims.spatial_dims[i].filter_size - 1; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int total_pad_in_size = + padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int before_pad_in_size = + (total_pad_in_size > 0 && padding_ == Padding::SAME) + ? total_pad_in_size / 2 + : 0; + + padding[i] = {before_pad_in_size, total_pad_in_size - before_pad_in_size}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + } // Besides padding the input, we will also expand output_rows to // expanded_out_rows = (output_rows - 1) * stride + 1 @@ -341,33 +359,54 @@ class Conv2DBackpropFilterOp : public XlaOpKernel { // convolution HLO below. 
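Plugging the `INPUT = [ A B C ]`, `FILTER = [ x y ]`, stride-2 example from the comment above into this arithmetic: `expanded_output_size = (1 - 1) * 2 + 1 = 1`, so `padded_in_size = 1 + 2 - 1 = 2` and `total_pad_in_size = 2 - 3 = -1`, i.e. one element of negative padding on the right, which drops the unused `C`. A NumPy check (illustrative only):

```python
import numpy as np

A, B, C = 1.0, 2.0, 3.0
inputs = np.array([A, B, C])
padded_in_size = 1 + 2 - 1         # expanded_output_size + filter_size - 1
total_pad = padded_in_size - 3     # -1: negative padding on the right
trimmed = inputs[:padded_in_size]  # [A, B]; C is dropped
dy = 1.0                           # upstream grad of a = A*x + B*y
dfilter = trimmed * dy             # [da/dx, da/dy] == [A, B]
print(dfilter)                     # [1. 2.]
```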
auto filter_backprop = ctx->builder()->ConvGeneralDilated( activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{top_pad_in_rows, total_pad_in_rows - top_pad_in_rows}, - {left_pad_in_cols, total_pad_in_cols - left_pad_in_cols}}, - /*lhs_dilation=*/{1, 1}, - /*rhs_dilation=*/{dims.rows.stride, dims.cols.stride}, dnums); + /*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation, + dnums); // The layout of filter_backprop will match the layout of // padded_activations - // and so will have layout: [out_feature, h, w, in_feature] - // Tensorflow filter shape is [ H, W, inC, outC ], so we transpose the + // and so will have layout: [out_feature, h, w, ..., in_feature] + // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the // output. + std::vector<int64> transpose_dims; + transpose_dims.reserve(num_dims()); + for (int i = 0; i < num_spatial_dims_; ++i) { + transpose_dims.push_back(dnums.spatial_dimensions(i)); + } + transpose_dims.push_back(c_dim); + transpose_dims.push_back(n_dim); xla::ComputationDataHandle filter_backprop_reshaped = - ctx->builder()->Transpose(filter_backprop, - {h_dim, w_dim, c_dim, n_dim}); + ctx->builder()->Transpose(filter_backprop, transpose_dims); ctx->SetOutput(0, filter_backprop_reshaped); } - private: + protected: + int num_spatial_dims_; std::vector<int32> strides_; Padding padding_; - TensorFormat data_format_; + TensorFormat data_format_ = FORMAT_NHWC; - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropFilterOp); + private: + TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); }; -REGISTER_XLA_OP("Conv2DBackpropInput", Conv2DBackpropInputOp); -REGISTER_XLA_OP("Conv2DBackpropFilter", Conv2DBackpropFilterOp); +class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { + public: + explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx) + : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp); + +class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { + public: + explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx) + : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 3cd0b39c871..de93a88f064 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -18,9 +18,9 @@ limitations under the License.
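Returning to the filter-backprop transpose just above: for 2-D NHWC the generalized permutation reduces to the previously hard-coded one (indices assumed from the NHWC layout):

```cpp
// Assumed 2-D NHWC values: spatial dims {1, 2}, c_dim = 3, n_dim = 0, so
//   transpose_dims = {1, 2, 3, 0};
// Dimension i of the result is dimension transpose_dims[i] of
// filter_backprop, restoring the TF filter layout [H, W, inC, outC].
```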
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index f0687c1d4b5..ba38693325c 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -32,9 +32,7 @@ namespace tensorflow { // description of the operation; and Computation adds the // implementation of the operation to a xla::ComputationBuilder. For most // arithmetic Ops XLA handles the broadcasting automatically given the input -// tensors. Ops like ReluGrad that need to map a scalar function over the inputs -// can use the XlaBinaryMapOp subclass below which handles manual -// broadcasting of the inputs. +// tensors. class XlaBinaryOp : public XlaOpKernel { public: explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -83,6 +81,8 @@ class XlaBinaryOp : public XlaOpKernel { // virtual methods to override: description is a textual description // of the mapped function; and BuildMapLambda adds the // implementation of the lambda to a xla::ComputationBuilder. +// Operations may have better performance if implemented as graphs of +// element-wise tensor operations. class XlaBinaryMapOp : public XlaBinaryOp { public: explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc deleted file mode 100644 index d96ff341789..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/core/framework/kernel_def_builder.h" - -namespace tensorflow { -namespace { - -// This OpKernel implements the Constant Op for XLA JIT -// devices. It extracts the constant Tensor from the Proto at kernel -// construction time, and then every time the Constant Op is executed -// an expression containing the constant is compiled. 
-class ConstantDeclarationOp : public XlaOpKernel { - public: - explicit ConstantDeclarationOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx), tensor_(ctx->output_type(0)) { - const TensorProto* proto = nullptr; - OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); - // MakeTensorFromProto uses the cpu_allocator, so tensor_ is a - // "real" tensor backed by CPU memory, holding the value of the - // constant. - OP_REQUIRES_OK(ctx, MakeTensorFromProto(*proto, &tensor_)); - OP_REQUIRES( - ctx, ctx->output_type(0) == tensor_.dtype(), - errors::InvalidArgument( - "Type mismatch between value (", DataTypeString(tensor_.dtype()), - ") and dtype (", DataTypeString(ctx->output_type(0)), ")")); - } - - void Compile(XlaOpKernelContext* ctx) override { - ctx->SetConstantOutput(0, tensor_); - } - - private: - // Extract the value of the constant from the Proto during Op kernel - // construction. The constant must be stored in a Tensor allocated - // using the cpu_allocator so that it is backed by real memory. The - // OpKernelConstruction's default allocator is the JITAllocator - // which only allocates enough space for metadata for each Tensor. - static Status MakeTensorFromProto(const TensorProto& tensor_proto, - Tensor* tensor) { - Tensor parsed(tensor_proto.dtype()); - if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - tensor_proto.DebugString()); - } - *tensor = parsed; - return Status::OK(); - } - - // This is a "real" tensor backed by CPU memory, containing the - // constant values. - Tensor tensor_; - TF_DISALLOW_COPY_AND_ASSIGN(ConstantDeclarationOp); -}; - -REGISTER_XLA_OP("Const", ConstantDeclarationOp); - -// This OpKernel implements the _Arg Op for XLA JIT devices. It -// associates its output with one of the arguments to a -// subcomputation. -class ArgOp : public XlaOpKernel { - public: - explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_)); - } - - void Compile(XlaOpKernelContext* ctx) override { - // If 'frame' is non-null, this is a function call inside an outer JIT - // compilation. Use the usual implementation of _Arg. - auto frame = ctx->call_frame(); - if (frame != nullptr) { - Tensor val; - OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); - OP_REQUIRES(ctx, val.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(val.dtype()), - " vs. expect ", DataTypeString(dtype_))); - // Forwards the argument from the frame. - ctx->op_kernel_context()->set_output(0, val); - return; - } - - XlaContext& tc = XlaContext::Get(ctx); - - OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(), - errors::InvalidArgument("Invalid argument index ", index_)); - const XlaCompiler::Argument& arg = tc.args()[index_]; - - if (arg.parameter < 0) { - ctx->SetConstantOutput(0, arg.constant_value); - } else { - ctx->SetOutput(0, tc.parameter(arg.parameter)); - } - } - - private: - int index_; - DataType dtype_; - xla::PrimitiveType type_; // Corresponding XLA type. 
- - TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); -}; - -REGISTER_XLA_OP("_Arg", ArgOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc index d408ab3338e..852d2a966ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -172,15 +172,14 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel { } else { // These will be used to define the bounds of each slice. // Within the loop, the input_channel index will be modified. - gtl::InlinedVector<int64, 4> filter_begin; - gtl::InlinedVector<int64, 4> filter_limits; - gtl::InlinedVector<int64, 4> input_begin; - gtl::InlinedVector<int64, 4> input_limits; + gtl::InlinedVector<int64, 4> filter_begin(4, 0); + gtl::InlinedVector<int64, 4> filter_limits(4); + gtl::InlinedVector<int64, 4> input_begin(4, 0); + gtl::InlinedVector<int64, 4> input_limits(4); + gtl::InlinedVector<int64, 4> strides(4, 1); for (int i = 0; i < 4; ++i) { - filter_begin.push_back(0); - filter_limits.push_back(filter_shape.dim_size(i)); - input_begin.push_back(0); - input_limits.push_back(input_shape.dim_size(i)); + filter_limits[i] = filter_shape.dim_size(i); + input_limits[i] = input_shape.dim_size(i); } std::vector<int64> strides_for_tla{strides_[1], strides_[2]}; @@ -209,9 +208,9 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel { input_limits[3] = i + 1; xla::ComputationDataHandle filter_slice = - b.Slice(filter, filter_begin, filter_limits); + b.Slice(filter, filter_begin, filter_limits, strides); xla::ComputationDataHandle input_slice = - b.Slice(input, input_begin, input_limits); + b.Slice(input, input_begin, input_limits, strides); convs.push_back(b.ConvWithGeneralDimensions( input_slice, filter_slice, strides_for_tla, xla_padding, dims)); } @@ -229,7 +228,8 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp); }; -REGISTER_XLA_OP("DepthwiseConv2dNative", DepthwiseConv2dNativeOp); +REGISTER_XLA_OP(Name("DepthwiseConv2dNative").TypeConstraint("T", kFloatTypes), + DepthwiseConv2dNativeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index b89109ff6ab..ec5017f6ab9 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License.
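A recurring change in the hunks above and below is the migration to the four-argument Slice, which takes explicit strides; unit strides reproduce the old behaviour. For the per-channel depthwise slice above, the calls amount to something like this (shape names assumed for illustration):

```cpp
// Depthwise filter of shape [H, W, C, M], selecting input channel i:
//   filter_begin  = {0, 0, i, 0};
//   filter_limits = {H, W, i + 1, M};
//   strides       = {1, 1, 1, 1};  // unit strides == old 3-arg Slice
//   auto filter_slice = b.Slice(filter, filter_begin, filter_limits, strides);
```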
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -64,7 +64,7 @@ class DiagOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Diag", DiagOp); +REGISTER_XLA_OP(Name("Diag"), DiagOp); class DiagPartOp : public XlaOpKernel { public: @@ -125,14 +125,14 @@ class DiagPartOp : public XlaOpKernel { diag = builder->Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}); + diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); diag = builder->Reshape(diag, new_dims); ctx->SetOutput(0, diag); } }; -REGISTER_XLA_OP("DiagPart", DiagPartOp); +REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp); class MatrixDiagOp : public XlaOpKernel { public: @@ -167,7 +167,7 @@ class MatrixDiagOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("MatrixDiag", MatrixDiagOp); +REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); class MatrixDiagPartOp : public XlaOpKernel { public: @@ -224,8 +224,9 @@ class MatrixDiagPartOp : public XlaOpKernel { } else if (actual_size > target_size) { std::vector<int64> start(flattened_dims.size(), 0); std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); + std::vector<int64> strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits); + diag = builder->Slice(diag, start, limits, strides); } // Reshape so the target values are in the first position of the last @@ -238,8 +239,9 @@ class MatrixDiagPartOp : public XlaOpKernel { // Slices out the first column and reshapes to the final shape. std::vector<int64> start(dims.size(), 0); std::vector<int64> limits(dims.begin(), dims.end()); + std::vector<int64> strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits); + diag = builder->Slice(diag, start, limits, strides); // Collapses away the last dimension. dims.pop_back(); @@ -249,7 +251,7 @@ class MatrixDiagPartOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("MatrixDiagPart", MatrixDiagPartOp); +REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 2936e792619..107c673f4a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -157,6 +157,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); + std::vector stride(1 + data0_shape.dims() - + indices0_shape.dims(), 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -169,7 +171,7 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit); + ctx->builder()->Slice(expression, slice_start, slice_limit, stride); } ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); @@ -194,7 +196,7 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("DynamicStitch", DynamicStitchOp); +REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc new file mode 100644 index 00000000000..62a5e1bd421 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of XLA Elu Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class EluOp : public XlaOpKernel { + public: + explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. 
+ void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + } +}; + +class EluGradOp : public XlaOpKernel { + public: + explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Return the lhs (incoming gradient) if the rhs (the Elu output) > 0, + // otherwise return lhs * (1 + rhs). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto exp_grad = b->Mul(grad, b->Add(activation, one)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Elu"), EluOp); +REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 918c80aad8c..1e1d2a1b4b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA-specific Fill Op. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -50,6 +50,7 @@ class FillOp : public XlaOpKernel { // Convert the dims literal into a vector that we can pass to // ComputationBuilder. std::vector<int64> broadcast; + broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i})); } @@ -68,7 +69,7 @@ class FillOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Fill", FillOp); +REGISTER_XLA_OP(Name("Fill"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 53f2196dc59..8dacb6627bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -47,8 +47,8 @@ class PassOn : public XlaOpKernel { } }; -REGISTER_XLA_OP("_ListToArray", PassOn); -REGISTER_XLA_OP("_ArrayToList", PassOn); +REGISTER_XLA_OP(Name("_ListToArray"), PassOn); +REGISTER_XLA_OP(Name("_ArrayToList"), PassOn); // TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp // implementation from regular Tensorflow.
Once XLA has been open sourced @@ -68,7 +68,8 @@ class SymbolicGradientOp : public AsyncOpKernel { done); OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done); + ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), + done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); @@ -104,7 +105,7 @@ class SymbolicGradientOp : public AsyncOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); }; -REGISTER_XLA_OP(kGradientOp, SymbolicGradientOp); +REGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index b98d3864790..49eadaf9d1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -93,12 +93,10 @@ class GatherOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); }; -REGISTER_XLA_OP("Gather", GatherOp); - -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .TypeConstraint("Tindices", {DT_INT32, DT_INT64})); +REGISTER_XLA_OP(Name("Gather") + .TypeConstraint("Tparams", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT), + GatherOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index eff23bd77d2..691a0b972d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int32_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { tensorflow::gather_float_int32_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index ae31f6f2006..3dff6e2737b 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int64_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { tensorflow::gather_float_int64_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 01417a3cdf7..87d3d64a4e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" namespace tensorflow { namespace { @@ -31,9 +31,12 @@ class IdentityOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(IdentityOp); }; -REGISTER_XLA_OP("Identity", IdentityOp); -REGISTER_XLA_OP("PreventGradient", IdentityOp); -REGISTER_XLA_OP("StopGradient", IdentityOp); +// XLA_* devices also register a "real" Identity operator so we suppress the +// dummy operator using CompilationOnly(). +REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); + +REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); +REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 293705e39fc..df002dddd04 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -16,9 +16,9 @@ limitations under the License. // Native XLA implementations of indexing ops. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -134,9 +134,9 @@ class ArgMaxOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxOp); }; -REGISTER_XLA_OP("ArgMax", ArgMaxOp); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("ArgMax").TypeConstraint("T", DT_FLOAT)); +REGISTER_XLA_OP( + Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), + ArgMaxOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 0033a949a37..afbd64ca503 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -43,7 +44,6 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_1d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index be8ad2317c9..841ff2f4df7 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,7 +46,6 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_2d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 248984bcfec..d096415087e 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -47,7 +47,7 @@ class L2LossOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("L2Loss", L2LossOp); +REGISTER_XLA_OP(Name("L2Loss"), L2LossOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 93966d3d5a9..759d1a1a2d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -69,7 +69,7 @@ class LRNOp : public XlaOpKernel { float beta_; }; -REGISTER_XLA_OP("LRN", LRNOp); +REGISTER_XLA_OP(Name("LRN"), LRNOp); class LRNGradOp : public XlaOpKernel { public: @@ -167,7 +167,7 @@ class LRNGradOp : public XlaOpKernel { float beta_; }; -REGISTER_XLA_OP("LRNGrad", LRNGradOp); +REGISTER_XLA_OP(Name("LRNGrad"), LRNGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 5af6a79f3e4..5c799a0e4f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -15,9 +15,9 @@ limitations under the License. // XLA-specific MatMul Op. -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -73,7 +73,7 @@ class MatMulOp : public XlaOpKernel { bool transpose_b_; }; -REGISTER_XLA_OP("MatMul", MatMulOp); +REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp); class SparseMatMulOp : public MatMulOp { public: @@ -82,7 +82,10 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP("SparseMatMul", SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul") + .TypeConstraint("Ta", kFloatTypes) + .TypeConstraint("Tb", kFloatTypes), + SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 806bfc604f7..b8f0c0b9fe6 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/no_op.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { -REGISTER_XLA_OP("NoOp", NoOp); +// XLA_* devices also register a "real" NoOp operator so we suppress the +// dummy operator using CompilationOnly(). +REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc new file mode 100644 index 00000000000..2a9cfcb2eb8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA implementation of OneHot operator. + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class OneHotOp : public XlaOpKernel { + public: + explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape indices_shape = ctx->InputShape(0); + const TensorShape depth_shape = ctx->InputShape(1); + const TensorShape on_value_shape = ctx->InputShape(2); + const TensorShape off_value_shape = ctx->InputShape(3); + + const int indices_dims = indices_shape.dims(); + const int output_dims = indices_dims + 1; + + // Preliminary validation of sizes. + OP_REQUIRES( + ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims), + errors::InvalidArgument("Expected axis to be -1 or between [0, ", + output_dims, "). But received: ", axis_)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape), + errors::InvalidArgument("depth must be a scalar, but got: ", + depth_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape), + errors::InvalidArgument("on_value must be a scalar, but got: ", + on_value_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape), + errors::InvalidArgument("off_value must be a scalar, but got: ", + off_value_shape.DebugString())); + + const int axis = (axis_ == -1) ? indices_dims : axis_; + + // The one-hot dimension. + int64 depth; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth)); + OP_REQUIRES( + ctx, depth >= 0, + errors::InvalidArgument("depth must be non-negative, got: ", depth)); + + xla::ComputationDataHandle one_hot; + OP_REQUIRES_OK( + ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), + indices_shape, ctx->Input(0), ctx->Input(2), + ctx->Input(3), &one_hot)); + ctx->SetOutput(0, one_hot); + } + + private: + int32 axis_; + + TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); +}; + +REGISTER_XLA_OP(Name("OneHot"), OneHotOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 7456d92de03..a4318e29d25 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -19,9 +19,9 @@ limitations under the License. 
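For intuition about what the OneHot lowering above produces, here is a naive reference for the default axis == -1 case (a hypothetical helper with assumed example values, not the XLA implementation):

```cpp
#include <vector>

std::vector<std::vector<float>> OneHotLastAxis(const std::vector<int>& indices,
                                               int depth, float on,
                                               float off) {
  std::vector<std::vector<float>> out(indices.size(),
                                      std::vector<float>(depth, off));
  for (size_t i = 0; i < indices.size(); ++i) {
    // Out-of-range indices produce an all-off row, matching TF semantics.
    if (indices[i] >= 0 && indices[i] < depth) out[i][indices[i]] = on;
  }
  return out;
}
// OneHotLastAxis({0, 2}, 3, 1.0f, 0.0f) == {{1, 0, 0}, {0, 0, 1}}
```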
#include <vector> #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -87,7 +87,7 @@ class PackOp : public XlaOpKernel { int axis_; }; -REGISTER_XLA_OP("Pack", PackOp); +REGISTER_XLA_OP(Name("Pack"), PackOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 2846414c5ec..22476f4a0c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -74,7 +74,7 @@ class PadOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Pad", PadOp); +REGISTER_XLA_OP(Name("Pad"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 7a1ce2db85c..2b6053d19dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA specific pooling ops. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" @@ -35,19 +35,21 @@ namespace { // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: - explicit PoolingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - // Data format doesn't matter since the kernel is specified explicitly. + PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { std::vector<int32> ksize_int; std::vector<int32> stride_int; OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); - OP_REQUIRES(ctx, ksize_int.size() == 4, + OP_REQUIRES(ctx, ksize_int.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); + "specify ", + num_dims(), " dimensions")); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); - OP_REQUIRES(ctx, stride_int.size() == 4, + OP_REQUIRES(ctx, stride_int.size() == num_dims(), errors::InvalidArgument("Sliding window stride field must " - "specify 4 dimensions")); - for (int i = 0; i < 4; ++i) { + "specify ", + num_dims(), " dimensions")); + for (int i = 0; i < num_dims(); ++i) { ksize_.push_back(ksize_int[i]); stride_.push_back(stride_int[i]); } @@ -56,6 +58,8 @@ class PoolingOp : public XlaOpKernel { padding_ = (padding == VALID) ?
xla::Padding::kValid : xla::Padding::kSame; } + int num_dims() const { return num_spatial_dims_ + 2; } + // Method that builds an initial value to use in reductions. virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, DataType data_type) = 0; @@ -73,6 +77,11 @@ class PoolingOp : public XlaOpKernel { xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); + const DataType type = input_type(0); xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_, @@ -81,14 +90,17 @@ class PoolingOp : public XlaOpKernel { } protected: + const int num_spatial_dims_; std::vector<int64> ksize_; std::vector<int64> stride_; xla::Padding padding_; + TensorFormat data_format_ = FORMAT_NHWC; }; class MaxPoolOp : public PoolingOp { public: - explicit MaxPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) {} + MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, DataType data_type) override { @@ -107,7 +119,24 @@ class MaxPoolOp : public PoolingOp { } }; -REGISTER_XLA_OP("MaxPool", MaxPoolOp); +class MaxPool2DOp : public MaxPoolOp { + public: + explicit MaxPool2DOp(OpKernelConstruction* ctx) + : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); + +class MaxPool3DOp : public MaxPoolOp { + public: + explicit MaxPool3DOp(OpKernelConstruction* ctx) + : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); // Common computation shared between AvgPool and AvgPoolGrad. Divide each // element of an image by the count of elements that contributed to that @@ -116,7 +145,7 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, const std::vector<int64>& ksize, const std::vector<int64>& stride, - TensorFormat data_format) { + int num_spatial_dims, TensorFormat data_format) { if (padding == xla::Padding::kValid) { // In VALID padding, all windows have the same number of elements // contributing to each average. Divide by the window size everywhere to @@ -134,34 +163,37 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( // TODO(phawkins): use a less brute-force way to compute this. Only // the boundary regions will have interesting values here.
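The hunk below generalizes this count computation to N spatial dimensions. In one dimension, the ReduceWindow-over-ones trick yields counts like the following standalone sketch (SAME-padding arithmetic assumed):

```cpp
#include <algorithm>
#include <vector>

// counts[o] = number of real (unpadded) inputs covered by output window o.
std::vector<int> WindowCounts(int input, int ksize, int stride) {
  int output = (input + stride - 1) / stride;  // SAME output size
  int pad_total = std::max((output - 1) * stride + ksize - input, 0);
  int pad_lo = pad_total / 2;
  std::vector<int> counts(output, 0);
  for (int o = 0; o < output; ++o) {
    for (int k = 0; k < ksize; ++k) {
      int idx = o * stride - pad_lo + k;
      if (idx >= 0 && idx < input) ++counts[o];
    }
  }
  return counts;  // e.g. WindowCounts(4, 3, 1) == {2, 3, 3, 2}
}
```

As the TODO notes, only boundary windows have counts below ksize; interior windows all see the full window.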
- int height_dim = GetTensorDimIndex(data_format, 'H'); - int width_dim = GetTensorDimIndex(data_format, 'W'); - CHECK_LT(height_dim, width_dim); + std::vector<int64> input_dim_sizes(num_spatial_dims); + std::vector<int64> window_dims(num_spatial_dims); + std::vector<int64> window_ksize(num_spatial_dims); + std::vector<int64> window_stride(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); + input_dim_sizes[i] = input_shape.dim_size(dim); + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } // Build a matrix of all 1s, with the same width/height as the input. auto ones = ctx->builder()->Broadcast( - XlaHelpers::One(ctx->builder(), dtype), - {input_shape.dim_size(height_dim), input_shape.dim_size(width_dim)}); + XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. auto counts = ctx->builder()->ReduceWindow( ones, XlaHelpers::Zero(ctx->builder(), dtype), - *ctx->GetOrCreateAdd(dtype), {ksize[height_dim], ksize[width_dim]}, - {stride[height_dim], stride[width_dim]}, xla::Padding::kSame); + *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, + xla::Padding::kSame); - return ctx->builder()->Div(output, counts, {height_dim, width_dim}); + return ctx->builder()->Div(output, counts, window_dims); } } class AvgPoolOp : public PoolingOp { public: - explicit AvgPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); - } + AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) + : PoolingOp(ctx, num_spatial_dims) {} xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, DataType data_type) override { @@ -177,14 +209,29 @@ class AvgPoolOp : public PoolingOp { XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, DataType dtype, const TensorShape& input_shape) override { return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, data_format_); + ksize_, stride_, num_spatial_dims_, + data_format_); } - - private: - TensorFormat data_format_; }; -REGISTER_XLA_OP("AvgPool", AvgPoolOp); +class AvgPool2DOp : public AvgPoolOp { + public: + explicit AvgPool2DOp(OpKernelConstruction* ctx) + : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); + +class AvgPool3DOp : public AvgPoolOp { + public: + explicit AvgPool3DOp(OpKernelConstruction* ctx) + : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp); // The operation to compute MaxPool gradients. // It takes three inputs: @@ -194,35 +241,39 @@ REGISTER_XLA_OP("AvgPool", AvgPoolOp); // It produces one output: backprop tensor for input.
class MaxPoolGradOp : public XlaOpKernel { public: - explicit MaxPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); - OP_REQUIRES(ctx, ksize_.size() == 4, + OP_REQUIRES(ctx, ksize_.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); + "specify ", + num_dims(), " dimensions")); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); - OP_REQUIRES(ctx, stride_.size() == 4, + OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + "specify ", + num_dims(), " dimensions")); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); } + int num_dims() const { return num_spatial_dims_ + 2; } + void Compile(XlaOpKernelContext* ctx) override { const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); const TensorShape out_backprop_shape = ctx->InputShape(2); - // For maxpooling, tensor_in should have 4 dimensions. - OP_REQUIRES(ctx, tensor_in_shape.dims() == 4, - errors::InvalidArgument("tensor_in must be 4-dimensional")); - OP_REQUIRES(ctx, tensor_out_shape.dims() == 4, - errors::InvalidArgument("tensor_out must be 4-dimensional")); - // For maxpooling, out_backprop should have 4 dimensions. - OP_REQUIRES(ctx, out_backprop_shape.dims() == 4, - errors::InvalidArgument("out_backprop must be 4-dimensional")); + // For maxpooling, tensor_in should have num_dims() dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_in must be ", num_dims(), + "-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_out must be ", num_dims(), + "-dimensional")); + // For maxpooling, out_backprop should have num_dims() dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate // whether this is a good time/space tradeoff. 
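The length checks above all key off num_dims() == num_spatial_dims_ + 2; for reference, the implied attribute shapes (example values assumed):

```cpp
// MaxPool   (2-D, NHWC):  ksize/strides have 4 entries, e.g. {1, kH, kW, 1}
// MaxPool3D (3-D, NDHWC): ksize/strides have 5 entries, e.g. {1, kD, kH, kW, 1}
```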
@@ -245,55 +296,74 @@ class MaxPoolGradOp : public XlaOpKernel { ctx->SetOutput(0, gradients); } - private: + protected: + const int num_spatial_dims_; std::vector<int32> ksize_; std::vector<int32> stride_; Padding padding_; - TensorFormat data_format_; + TensorFormat data_format_ = FORMAT_NHWC; }; -REGISTER_XLA_OP("MaxPoolGrad", MaxPoolGradOp); - -// Average-pooling gradient -class AvgPoolGradOp : public XlaOpKernel { +class MaxPool2DGradOp : public MaxPoolGradOp { public: - explicit AvgPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit MaxPool2DGradOp(OpKernelConstruction* ctx) + : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) { string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); + +class MaxPool3DGradOp : public MaxPoolGradOp { + public: + explicit MaxPool3DGradOp(OpKernelConstruction* ctx) + : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp); + +// Average-pooling gradient +class AvgPoolGradOp : public XlaOpKernel { + public: + AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); - OP_REQUIRES(ctx, ksize_.size() == 4, + OP_REQUIRES(ctx, ksize_.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); + "specify ", + num_dims(), " dimensions")); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); - OP_REQUIRES(ctx, stride_.size() == 4, + OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + "specify ", + num_dims(), " dimensions")); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); } + int num_dims() const { return num_spatial_dims_ + 2; } + void Compile(XlaOpKernelContext* ctx) override { TensorShape gradients_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape)); const TensorShape out_backprop_shape = ctx->InputShape(1); - // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements. - OP_REQUIRES( - ctx, gradients_shape.dims() == 4, - errors::InvalidArgument("orig_input_shape must have 4 elements")); + // For avgpooling, tensor_in_shape should have num_dims() dimensions. + OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(), + errors::InvalidArgument("orig_input_shape must be ", num_dims(), + "-dimensional")); - // For avgpooling, out_backprop should have 4 dimensions. - OP_REQUIRES(ctx, out_backprop_shape.dims() == 4, - errors::InvalidArgument("out_backprop must be 4-dimensional")); + // For avgpooling, out_backprop should have num_dims() dimensions.
+ OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); - int height_dim = GetTensorDimIndex(data_format_, 'H'); - int width_dim = GetTensorDimIndex(data_format_, 'W'); - int depth = GetTensorDim(out_backprop_shape, data_format_, 'C'); + int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); + int64 depth = out_backprop_shape.dim_size(depth_dim); // We can think of average-pooling as: // * a convolution with a kernel consisting entirely of 1s, where the @@ -308,16 +378,23 @@ class AvgPoolGradOp : public XlaOpKernel { // For an explanation of backpropagation for convolution, see the comments // in third_party/tensorflow/core/kernels/conv_grad_ops.h - // TF filter shape is [ H, W, inC, outC ] - TensorShape filter_shape( - {ksize_[height_dim], ksize_[width_dim], depth, depth}); + // TF filter shape is [ H, W, ..., inC, outC ] + std::vector<int64> filter_dims(num_dims()); + for (int i = 0; i < num_spatial_dims_; ++i) { + int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + filter_dims[i] = ksize_[dim]; + } + filter_dims[num_dims() - 2] = depth; + filter_dims[num_dims() - 1] = depth; + TensorShape filter_shape(filter_dims); // Reuse the logic from Conv2DBackpropInput to compute padding. - Conv2DBackpropDimensions dims; + ConvBackpropDimensions dims; OP_REQUIRES_OK( - ctx, Conv2DBackpropComputeDimensions( - "AvgPoolGrad", gradients_shape, filter_shape, - out_backprop_shape, stride_, padding_, data_format_, &dims)); + ctx, ConvBackpropComputeDimensions( + type_string(), /*num_spatial_dims=*/num_spatial_dims_, + gradients_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); auto out_backprop = ctx->Input(1); @@ -332,43 +409,60 @@ class AvgPoolGradOp : public XlaOpKernel { // Divide the out_backprop values by the counts for each spatial position. std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = - AvgPoolDivideByCount(ctx, out_backprop, dtype, gradients_shape, - xla_padding, ksize_, stride_int64s, data_format_); + auto out_backprop_div = AvgPoolDivideByCount( + ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, + stride_int64s, num_spatial_dims_, data_format_); // Pad the gradients in the spatial dimensions. We use the same padding // as Conv2DBackpropInput.
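The padded-gradient-plus-ReduceWindow construction that follows computes the same thing as this naive 1-D reference for the VALID case (a hypothetical sketch, not the XLA graph):

```cpp
#include <vector>

// Every input in window o contributed 1/ksize of that window's average,
// so it receives an equal 1/ksize share of the incoming gradient.
std::vector<float> AvgPoolGrad1D(const std::vector<float>& out_backprop,
                                 int input, int ksize, int stride) {
  std::vector<float> in_backprop(input, 0.0f);
  for (size_t o = 0; o < out_backprop.size(); ++o) {
    for (int k = 0; k < ksize; ++k) {
      in_backprop[o * stride + k] += out_backprop[o] / ksize;
    }
  }
  return in_backprop;
}
// AvgPoolGrad1D({1.0f, 1.0f}, 4, 2, 2) == {0.5f, 0.5f, 0.5f, 0.5f}
```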
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(4); - auto* row_padding = padding_config.mutable_dimensions(height_dim); - row_padding->set_edge_padding_low(dims.rows.pad_before); - row_padding->set_edge_padding_high(dims.rows.pad_after); - row_padding->set_interior_padding(dims.rows.stride - 1); - - auto* col_padding = padding_config.mutable_dimensions(width_dim); - col_padding->set_edge_padding_low(dims.cols.pad_before); - col_padding->set_edge_padding_high(dims.cols.pad_after); - col_padding->set_interior_padding(dims.cols.stride - 1); + xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); + for (int i = 0; i < num_spatial_dims_; ++i) { + int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + auto* padding = padding_config.mutable_dimensions(dim); + padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); + padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); + padding->set_interior_padding(dims.spatial_dims[i].stride - 1); + } auto zero = XlaHelpers::Zero(ctx->builder(), dtype); auto padded_gradients = ctx->builder()->Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients <conv> ones + std::vector<int64> ones(num_dims(), 1LL); xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, - /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kValid); + /* window_strides=*/ones, xla::Padding::kValid); ctx->SetOutput(0, in_backprop); } - private: + protected: + const int num_spatial_dims_; std::vector<int32> ksize_; std::vector<int32> stride_; Padding padding_; - TensorFormat data_format_; + TensorFormat data_format_ = FORMAT_NHWC; }; -REGISTER_XLA_OP("AvgPoolGrad", AvgPoolGradOp); +class AvgPool2DGradOp : public AvgPoolGradOp { + public: + explicit AvgPool2DGradOp(OpKernelConstruction* ctx) + : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp); + +class AvgPool3DGradOp : public AvgPoolGradOp { + public: + explicit AvgPool3DGradOp(OpKernelConstruction* ctx) + : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} +}; +REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 4ffe278d1c4..66b99665cbe 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -18,9 +18,10 @@ limitations under the License. // TODO(misard,phawkins): add tests.
#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -51,7 +52,7 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP("RandomUniform", RandomUniformOp); +REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp); class RandomUniformIntOp : public XlaOpKernel { public: @@ -82,7 +83,7 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP("RandomUniformInt", RandomUniformIntOp); +REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { public: @@ -110,7 +111,79 @@ class RandomStandardNormalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP("RandomStandardNormal", RandomStandardNormalOp); +REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp); + +class TruncatedNormalOp : public XlaOpKernel { + public: + explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const DataType dtype = output_type(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); + xla::Shape xla_element_shape = + xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); + xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); + xla::ComputationDataHandle candidate = + b->RngNormal(mean, stddev, xla_shape); + + auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) { + return XlaHelpers::FloatLiteral(b, dtype, negate ? 
-2.0 : 2.0); + }; + auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate, + xla::ComputationBuilder* b) { + xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); + xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); + return b->LogicalOr(too_large, too_small); + }; + + // The algorithm we're using is roughly: + // + // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { + // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd + // candidate = select(out_of_range_mask, rng_normal(), candidate) + // } + std::unique_ptr<xla::ComputationBuilder> test_builder = + b->CreateSubBuilder("truncated_normal_test"); + { + auto* b = test_builder.get(); + xla::ComputationDataHandle candidate = + b->Parameter(0, xla_shape, "candidate"); + xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b); + OP_REQUIRES_OK(ctx, Any(oor_mask, b).status()); + } + + std::unique_ptr<xla::ComputationBuilder> body_builder = + b->CreateSubBuilder("truncated_normal_body"); + { + auto* b = body_builder.get(); + xla::ComputationDataHandle candidate = + b->Parameter(0, xla_shape, "candidate"); + xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b); + xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); + xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); + b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); + } + + xla::StatusOr<xla::Computation> test_computation = test_builder->Build(); + OP_REQUIRES_OK(ctx, test_computation.status()); + xla::StatusOr<xla::Computation> body_computation = body_builder->Build(); + OP_REQUIRES_OK(ctx, body_computation.status()); + xla::ComputationDataHandle result = + b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(), + candidate); + + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index ac929af2e2b..518a9372c4f 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -17,8 +17,8 @@ limitations under the License.
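The While(test, body) computation built above implements rejection resampling: keep the in-range draws and redraw any element outside two standard deviations until none remain. A minimal per-element sketch in plain C++, with std::mt19937/std::normal_distribution standing in for RngNormal and the masked Select (illustrative only, not the XLA code path):

```c++
#include <cstdio>
#include <random>

// Draw from N(0, 1) and redraw any sample outside (-2, 2); this is the
// scalar view of the tensor-wide test/body loop in TruncatedNormalOp.
double TruncatedNormal(std::mt19937& gen) {
  std::normal_distribution<double> dist(0.0, 1.0);
  double candidate = dist(gen);
  while (candidate < -2.0 || candidate > 2.0) {  // out-of-range test
    candidate = dist(gen);                       // resample until in range
  }
  return candidate;
}

int main() {
  std::mt19937 gen(42);
  for (int i = 0; i < 4; ++i) std::printf("%f\n", TruncatedNormal(gen));
}
```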
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("Sum", SumOp); +REGISTER_XLA_OP(Name("Sum"), SumOp); class ProdOp : public XlaReductionOp { public: @@ -53,7 +53,7 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("Prod", ProdOp); +REGISTER_XLA_OP(Name("Prod"), ProdOp); class MinOp : public XlaReductionOp { public: @@ -73,7 +73,7 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("Min", MinOp); +REGISTER_XLA_OP(Name("Min"), MinOp); class MaxOp : public XlaReductionOp { public: @@ -93,7 +93,7 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("Max", MaxOp); +REGISTER_XLA_OP(Name("Max"), MaxOp); class MeanOp : public XlaReductionOp { public: @@ -105,17 +105,17 @@ class MeanOp : public XlaReductionOp { builder->Add(scalar_lhs, scalar_rhs); } - bool BuildFinalizer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, - int64 num_elements_reduced) override { + xla::ComputationDataHandle BuildFinalizer( + xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& reduce_output, + int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - builder->Div(scalar_argument, divisor); - return true; + return builder->Div(reduce_output, divisor); } }; -REGISTER_XLA_OP("Mean", MeanOp); +REGISTER_XLA_OP(Name("Mean"), MeanOp); class AllOp : public XlaReductionOp { public: @@ -133,7 +133,7 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("All", AllOp); +REGISTER_XLA_OP(Name("All"), AllOp); class AnyOp : public XlaReductionOp { public: @@ -151,7 +151,7 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP("Any", AnyOp); +REGISTER_XLA_OP(Name("Any"), AnyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 7f0dd26f914..9aca6d8fedf 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,16 +48,15 @@ class XlaReductionOp : public XlaOpKernel { const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) = 0; - // Implement the scalar->scalar lambda that should be applied to - // each element to be finalized. The desired computation should be - // added to 'builder' and 'scalar_argument' is the function's - // input. 'num_elements_reduced' is the number of elements that contributed - // to the reduction. If the reduction has a finalizer return true, otherwise - // return false and any computation added to builder will be - // ignored. Defaults to return false. - virtual bool BuildFinalizer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, - int64 num_elements_reduced); + // Applies a transformation to the output of the reduction. The desired + // computation should be added to 'builder'. Argument 'reduce_output' is the + // output of the reduction. 'num_elements_reduced' is the number of elements + // that contributed to the reduction. 
Returns the transformed reduction + // output. Defaults to returning 'reduce_output' unchanged. + virtual xla::ComputationDataHandle BuildFinalizer( + xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& reduce_output, + int64 num_elements_reduced); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index d6b085e8978..8798c80ad53 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -39,11 +39,11 @@ xla::ComputationDataHandle XlaReductionOp::InitialValue( // Unless BuildFinalizer is overridden the reduction has no // finalizer. -bool XlaReductionOp::BuildFinalizer( +xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_argument, + const xla::ComputationDataHandle& reduce_output, int64 num_elements_reduced) { - return false; + return reduce_output; } void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { @@ -121,28 +121,14 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::ComputationDataHandle reduce = ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - // Construct the builder for the finalizer lambda. - xla::ComputationBuilder f(ctx->builder()->client(), - strings::StrCat(desc, "-finalizer")); - // Make the scalar parameter of the desired type for the lambda. - xla::ComputationDataHandle fx = - f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - // Call virtual method to build the finalizer lambda. - bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced); - xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle pre_reshaped_data; - if (has_finalizer) { - // This reduction Op includes a finalizer so run it as a Map. - pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation); - } else { - pre_reshaped_data = reduce; - } + xla::ComputationDataHandle finalized = + BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); xla::ComputationDataHandle result; if (keep_dims_) { - result = ctx->builder()->Reshape(pre_reshaped_data, final_shape); + result = ctx->builder()->Reshape(finalized, final_shape); } else { - result = pre_reshaped_data; + result = finalized; } ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 8adac23eeec..a137d28118e 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -16,8 +16,8 @@ limitations under the License. // Native XLA implementations of XLA Relu Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -31,7 +31,7 @@ class ReluOp : public XlaOpKernel { public: explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0.
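The revised BuildFinalizer contract above replaces the old bool-plus-Map protocol with a direct post-processing hook on the reduction output (identity by default, division by the element count for Mean). A minimal sketch of that pattern under assumed names (SumReduction/MeanReduction are hypothetical stand-ins, not the actual XlaReductionOp hierarchy):

```c++
#include <cstdio>
#include <vector>

struct SumReduction {
  // Default finalizer: pass the raw reduction output through unchanged.
  virtual double Finalize(double reduce_output, long num_elements) const {
    return reduce_output;
  }
  double Run(const std::vector<double>& xs) const {
    double acc = 0.0;
    for (double x : xs) acc += x;           // the reduction proper
    return Finalize(acc, static_cast<long>(xs.size()));
  }
  virtual ~SumReduction() = default;
};

struct MeanReduction : SumReduction {
  // Mean overrides the finalizer: Sum divided by the element count.
  double Finalize(double reduce_output, long num_elements) const override {
    return reduce_output / num_elements;
  }
};

int main() {
  std::vector<double> xs = {1, 2, 3, 4};
  std::printf("sum=%g mean=%g\n", SumReduction().Run(xs),
              MeanReduction().Run(xs));  // sum=10 mean=2.5
}
```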
- void Compile(XlaOpKernelContext* ctx) { + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); @@ -42,7 +42,7 @@ class Relu6Op : public XlaOpKernel { public: explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Clamp the scalar input between 0 and 6. - void Compile(XlaOpKernelContext* ctx) { + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); @@ -50,43 +50,44 @@ class Relu6Op : public XlaOpKernel { } }; -// A subclass of a XlaBinaryMapOp must build the lambda computation -// that describes the (scalar,scalar)->scalar function to apply to -// each element of the input. We have to use XlaBinaryMapOp instead of -// XlaBinaryOp here because XLA Select does not do automatic -// broadcasting. -class ReluGradOp : public XlaBinaryMapOp { +class ReluGradOp : public XlaOpKernel { public: - explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. - void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - b->Select(b->Gt(feature, zero), gradient, zero); + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = b->Gt(ctx->Input(1), zero); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); } }; -class Relu6GradOp : public XlaBinaryMapOp { +class Relu6GradOp : public XlaOpKernel { public: - explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. 
- void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6); - b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)), - gradient, zero); + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = b->Broadcast( + XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); + auto out = b->Select( + b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); + ctx->SetOutput(0, out); } }; -REGISTER_XLA_OP("Relu", ReluOp); -REGISTER_XLA_OP("Relu6", Relu6Op); -REGISTER_XLA_OP("ReluGrad", ReluGradOp); -REGISTER_XLA_OP("Relu6Grad", Relu6GradOp); +REGISTER_XLA_OP(Name("Relu"), ReluOp); +REGISTER_XLA_OP(Name("Relu6"), Relu6Op); +REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); +REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index febce0e1267..df542350b44 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA-specific reshape Op. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Reshape", ReshapeOp); +REGISTER_XLA_OP(Name("Reshape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 87d11a38d4c..462267d1504 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -43,7 +43,7 @@ class RetvalOp : public XlaOpKernel { if (frame) { // If 'frame' is non-null, this is an inner function call inside a JIT // compilation. 
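The ReluGrad/Relu6Grad rewrites above drop the scalar map lambda in favor of a broadcast zero tensor and a Select over the whole input. Element by element, the two gradient rules reduce to the conditionals in this plain-C++ sketch (illustrative, not the XLA code path):

```c++
#include <cstdio>

// Pass the incoming gradient through where the forward input was active,
// zero elsewhere; Relu6 additionally zeroes the gradient at the upper clamp.
float ReluGrad(float dy, float x) { return x > 0.0f ? dy : 0.0f; }
float Relu6Grad(float dy, float x) {
  return (x > 0.0f && x < 6.0f) ? dy : 0.0f;
}

int main() {
  std::printf("%g %g %g\n", ReluGrad(1.0f, -1.0f), Relu6Grad(1.0f, 3.0f),
              Relu6Grad(1.0f, 7.0f));  // 0 1 0
}
```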
- frame->SetRetval(index_, input); + OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); @@ -58,9 +58,9 @@ class RetvalOp : public XlaOpKernel { if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - tc.AddConstRetval(index_, dtype_, literal); + OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - tc.AddRetval(index_, input); + tc.AddRetval(index_, dtype_, input); } } } @@ -73,7 +73,7 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP("_Retval", RetvalOp); +REGISTER_XLA_OP(Name("_Retval"), RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc new file mode 100644 index 00000000000..7489321f72f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific reverse Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +class ReverseOp : public XlaOpKernel { + public: + explicit ReverseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // r = tf.reverse(x, revdims) + const TensorShape x_shape = ctx->InputShape(0); + const TensorShape revd_shape = ctx->InputShape(1); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(revd_shape), + errors::InvalidArgument("axes must be a vector, not shape ", + revd_shape.DebugString())); + OP_REQUIRES(ctx, revd_shape.num_elements() == x_shape.dims(), + errors::InvalidArgument("axes ", revd_shape.DebugString(), + " must have same number of elements as" + " input tensor has dimensions ", + x_shape.DebugString(), ".")); + if (revd_shape.num_elements() == 0) { + ctx->SetOutput(0, ctx->Input(0)); + return; + } + // ComputationBuilder::Rev() requires concrete values for dimensions arg.
+ xla::Literal lax; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); + std::vector<bool> revdims(x_shape.dims()); + std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin()); + std::vector<int64> dimensions; + + for (int d = 0; d < x_shape.dims(); ++d) { + if (revdims[d]) { + dimensions.push_back(d); + } + } + + ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions)); + } +}; + +REGISTER_XLA_OP(Name("Reverse"), ReverseOp); + +class ReverseV2Op : public XlaOpKernel { + public: + explicit ReverseV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // r = tf.reverse(x, axes) + const TensorShape x_shape = ctx->InputShape(0); + const TensorShape axes_shape = ctx->InputShape(1); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(axes_shape), + errors::InvalidArgument("axes must be a vector, not shape ", + axes_shape.DebugString())); + OP_REQUIRES(ctx, axes_shape.num_elements() <= x_shape.dims(), + errors::InvalidArgument("axes ", axes_shape.DebugString(), + " can not have more elements" + " than input tensor has dimensions ", + x_shape.DebugString(), ".")); + // Reverse is a no-op if axes argument is empty. + if (axes_shape.num_elements() == 0) { + ctx->SetOutput(0, ctx->Input(0)); + return; + } + // ComputationBuilder::Rev() requires concrete values for dimensions arg. + std::vector<int64> axes; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); + + for (int d = 0; d < axes.size(); ++d) { + OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()), + errors::InvalidArgument(axes[d], " is out of range [0, ", + x_shape.dims(), ").")); + } + + ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes)); + } +}; + +REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 0fecc338ca5..8081d3c41c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -40,12 +40,19 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs.
", else_shape.DebugString())); + xla::ComputationBuilder* builder = ctx->builder(); + + auto cond_handle = ctx->Input(0); + auto then_handle = ctx->Input(1); + auto else_handle = ctx->Input(2); + bool broadcasting = !cond_shape.IsSameSize(then_shape); - if (broadcasting) { - OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(cond_shape), - errors::InvalidArgument("'cond' must be a vector, but saw shape: ", - cond_shape.DebugString())); + bool cond_is_scalar = TensorShapeUtils::IsScalar(cond_shape); + if (broadcasting && !cond_is_scalar) { + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(cond_shape), + errors::InvalidArgument( + "'cond' must be a scalar or a vector, but saw shape: ", + cond_shape.DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then_shape), errors::InvalidArgument( "'then' must be at least a vector, but saw shape: ", @@ -55,15 +62,7 @@ class SelectOp : public XlaOpKernel { "match size of 'cond', but saw: ", then_shape.dim_size(0), " vs. ", cond_shape.num_elements())); - } - xla::ComputationBuilder* builder = ctx->builder(); - - auto cond_handle = ctx->Input(0); - auto then_handle = ctx->Input(1); - auto else_handle = ctx->Input(2); - - if (broadcasting) { // TODO(phawkins): broadcasting on the right seems pretty awkward in // XLA. It seems we have to broadcast on the left and then Reshape // to get the dimensions in the right order. @@ -84,7 +83,7 @@ class SelectOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); }; -REGISTER_XLA_OP("Select", SelectOp); +REGISTER_XLA_OP(Name("Select"), SelectOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 42ae978c3ce..5b6fa64fa82 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -15,9 +15,9 @@ limitations under the License. // XLA-specific sequence and range Ops. -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -138,7 +138,7 @@ class RangeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Range", RangeOp); +REGISTER_XLA_OP(Name("Range"), RangeOp); class LinSpaceOp : public XlaOpKernel { public: @@ -207,7 +207,7 @@ class LinSpaceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("LinSpace", LinSpaceOp); +REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index e7eec1cefda..24a99f253d6 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA-specific Shape Ops. 
#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -48,7 +48,7 @@ class ShapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Shape", ShapeOp); +REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -78,7 +78,7 @@ class ShapeNOp : public XlaOpKernel { bool IsExpensive() override { return false; } }; -REGISTER_XLA_OP("ShapeN", ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -94,7 +94,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Rank", RankOp); +REGISTER_XLA_OP(Name("Rank"), RankOp); class SizeOp : public XlaOpKernel { public: @@ -113,7 +113,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Size", SizeOp); +REGISTER_XLA_OP(Name("Size"), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -163,7 +163,7 @@ class ExpandDimsOp : public XlaOpKernel { ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); } }; -REGISTER_XLA_OP("ExpandDims", ExpandDimsOp); +REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: @@ -225,7 +225,7 @@ class SqueezeOp : public XlaOpKernel { std::unordered_set squeeze_dims_; }; -REGISTER_XLA_OP("Squeeze", SqueezeOp); +REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp); class ZerosLikeOp : public XlaOpKernel { public: @@ -239,7 +239,21 @@ class ZerosLikeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("ZerosLike", ZerosLikeOp); +REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp); + +class OnesLikeOp : public XlaOpKernel { + public: + explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + + auto one = XlaHelpers::One(ctx->builder(), input_type(0)); + ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 8ec77e04afe..482c54a40cf 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA-specific Slice Op. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -50,10 +50,13 @@ class SliceOp : public XlaOpKernel { // slice will be an empty handle if the output has no elements. 
CHECK_EQ(begin.size(), size.size()); std::vector<int64> limits; + limits.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + size[i]); } - ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits)); + std::vector<int64> strides(begin.size(), 1); + ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits, + strides)); } private: @@ -115,7 +118,7 @@ void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, } } -REGISTER_XLA_OP("Slice", SliceOp); +REGISTER_XLA_OP(Name("Slice"), SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 06ee5201633..a0d8ab4d73f 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,9 +15,9 @@ limitations under the License. // XLA-specific Ops for softmax. -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -74,8 +74,53 @@ class SoftmaxOp : public XlaOpKernel { bool log_; }; -REGISTER_XLA_OP("Softmax", SoftmaxOp); -REGISTER_XLA_OP("LogSoftmax", SoftmaxOp); +REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); +REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); + +std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle> +CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, + const xla::ComputationDataHandle& logits, + const xla::ComputationDataHandle& labels) { + const xla::Computation& max_func = *ctx->GetOrCreateMax(type); + const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); + + const int kBatchDim = 0; + const int kClassDim = 1; + + xla::ComputationBuilder* b = ctx->builder(); + // Find the max in each batch, resulting in a tensor of shape [batch] + auto logits_max = + b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + + // Subtract the max in batch b from every element in batch b. + // Broadcasts along the batch dimension. + auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + + // exp(logits - max_logits) + auto exp_shifted_logits = b->Exp(shifted_logits); + + // sum_{class} (exp(logits - max_logits)) + auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), + add_func, {kClassDim}); + + // log(sum(exp(logits - max_logits))) + auto log_sum_exp = b->Log(sum_exp); + + // sum(-labels * + // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) + // along classes + // (The subtraction broadcasts along the batch dimension.)
+ xla::ComputationDataHandle loss = b->Reduce( + b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), + XlaHelpers::Zero(b, type), add_func, {kClassDim}); + + // backprop: prob - labels, where + // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) + // (where the division broadcasts along the batch dimension) + xla::ComputationDataHandle backprop = + b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + return {loss, backprop}; +} class SoftmaxXentWithLogitsOp : public XlaOpKernel { public: @@ -88,65 +133,95 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape), errors::InvalidArgument( "logits and labels must be same size: logits_size=", - logits_shape.DebugString(), " labels_size=", - labels_shape.DebugString())); + logits_shape.DebugString(), + " labels_size=", labels_shape.DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), errors::InvalidArgument("logits must be 2-dimensional")); // As we already tested that both inputs have the same shape no need to // check that "labels" is a matrix too. - // loss is 1-D (one per example), and size is batch_size. - - const int kBatchDim = 0; - const int kClassDim = 1; - const DataType type = input_type(0); - xla::ComputationBuilder* b = ctx->builder(); auto logits = ctx->Input(0); auto labels = ctx->Input(1); - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); - - // Find the max in each batch, resulting in a tensor of shape [batch] - auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); - - // Subtract the max in batch b from every element in batch b. - // Broadcasts along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - - // exp(logits - max_logits) - auto exp_shifted_logits = b->Exp(shifted_logits); - - // sum_{class} (exp(logits - max_logits)) - auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), - add_func, {kClassDim}); - - // log(sum(exp(logits - max_logits))) - auto log_sum_exp = b->Log(sum_exp); - - // sum(-labels * - // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) - // along classes - // (The subtraction broadcasts along the batch dimension.) 
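The shared CrossEntropyWithLogits above is the standard numerically stable softmax cross-entropy: shift by the per-row max, compute the loss via log-sum-exp, and obtain the gradient as softmax minus labels. A single-row plain-C++ sketch of the same arithmetic (names are illustrative):

```c++
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// loss = -sum(labels * (shifted - log(sum(exp(shifted)))))
// backprop = softmax(logits) - labels, with shifted = logits - max(logits).
void XentWithLogits(const std::vector<double>& logits,
                    const std::vector<double>& labels, double* loss,
                    std::vector<double>* backprop) {
  double max_logit = logits[0];
  for (double l : logits) max_logit = std::max(max_logit, l);
  double sum_exp = 0.0;
  for (double l : logits) sum_exp += std::exp(l - max_logit);
  const double log_sum_exp = std::log(sum_exp);
  *loss = 0.0;
  backprop->resize(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    const double shifted = logits[i] - max_logit;
    *loss += -labels[i] * (shifted - log_sum_exp);
    (*backprop)[i] = std::exp(shifted) / sum_exp - labels[i];
  }
}

int main() {
  double loss;
  std::vector<double> bp;
  XentWithLogits({2.0, 1.0, 0.1}, {1.0, 0.0, 0.0}, &loss, &bp);
  std::printf("loss=%f grad0=%f\n", loss, bp[0]);  // loss ~= 0.417
}
```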
- xla::ComputationDataHandle loss = - b->Reduce(b->Mul(b->Neg(labels), - b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), - XlaHelpers::Zero(b, type), add_func, {kClassDim}); - - // backprop: prob - labels, where - // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) - // (where the division broadcasts along the batch dimension) - xla::ComputationDataHandle backprop = - b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); - + xla::ComputationDataHandle loss, backprop; + std::tie(loss, backprop) = + CrossEntropyWithLogits(ctx, type, logits, labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } }; -REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp); +REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp); + +class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { + public: + explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape logits_shape = ctx->InputShape(0); + const TensorShape labels_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits must be 2-D, but got shape ", + logits_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape), + errors::InvalidArgument("labels must be 1-D, but got shape ", + labels_shape.DebugString())); + OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0), + errors::InvalidArgument( + "logits and labels must have the same first dimension, " + "got logits shape ", + logits_shape.DebugString(), " and labels shape ", + labels_shape.DebugString())); + OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0, + errors::InvalidArgument( + "Must have at least one class, but got logits shape ", + logits_shape.DebugString())); + + int64 batch_size = logits_shape.dim_size(0); + int64 depth = logits_shape.dim_size(1); + + DataType logits_type = input_type(0); + DataType indices_type = input_type(1); + + xla::ComputationDataHandle indices = ctx->Input(1); + + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle labels; + OP_REQUIRES_OK(ctx, + XlaHelpers::OneHot( + builder, depth, /*axis=*/1, input_type(1), labels_shape, + indices, XlaHelpers::One(builder, logits_type), + XlaHelpers::Zero(builder, logits_type), &labels)); + + // If any of the indices are out of range, we must populate the labels with + // NaNs to obey the interface contract of + // tf.nn.sparse_softmax_cross_entropy_with_logits. + // Builds a vector of {batch_size} that is 0 if the index is in range, or + // NaN otherwise; then add that vector to the labels to force out-of-range + // values to NaNs. 
+ xla::ComputationDataHandle nan_or_zero = builder->Select( + builder->LogicalAnd( + builder->Le(XlaHelpers::Zero(builder, indices_type), indices), + builder->Lt(indices, XlaHelpers::IntegerLiteral( + builder, indices_type, depth))), + builder->Broadcast(XlaHelpers::Zero(builder, logits_type), + {batch_size}), + builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), + {batch_size})); + labels = builder->Add(labels, nan_or_zero, {0}); + + xla::ComputationDataHandle loss, backprop; + std::tie(loss, backprop) = + CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); + ctx->SetOutput(0, loss); + ctx->SetOutput(1, backprop); + } +}; + +REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"), + SparseSoftmaxXentWithLogitsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc new file mode 100644 index 00000000000..f15b354cb26 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -0,0 +1,190 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +void SpaceToBatch(XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, DataType input_dtype, + const TensorShape& input_tensor_shape, + gtl::ArraySlice<int64> block_shape, + const xla::Literal& paddings) { + const int input_rank = input_tensor_shape.dims(); + const gtl::InlinedVector<int64, 4> input_shape = + input_tensor_shape.dim_sizes(); + const int block_rank = block_shape.size(); + + OP_REQUIRES( + ctx, input_rank >= 1 + block_rank, + errors::InvalidArgument("input rank should be >= ", 1 + block_rank, + " instead of ", input_rank)); + gtl::ArraySlice<int64> remainder_shape(input_shape); + remainder_shape.remove_prefix(1 + block_rank); + + OP_REQUIRES( + ctx, + xla::ShapeUtil::Rank(paddings.shape()) == 2 && + block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && + 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), + errors::InvalidArgument("paddings should have shape [", block_rank, + ", 2] instead of ", + xla::ShapeUtil::HumanString(paddings.shape()))); + + xla::ComputationBuilder* b = ctx->builder(); + + // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the + // input according to `paddings` to produce `padded` of shape `padded_shape`. + xla::PaddingConfig padding_config; + std::vector<int64> padded_shape(input_shape.begin(), input_shape.end()); + int64 block_num_elems = 1LL; + padding_config.add_dimensions(); // Don't pad the batch dimension.
+ for (int i = 0; i < block_rank; ++i) { + auto* dim = padding_config.add_dimensions(); + int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0}); + int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1}); + OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, + errors::InvalidArgument("Paddings must be non-negative")); + dim->set_edge_padding_low(pad_start); + dim->set_edge_padding_high(pad_end); + padded_shape[1 + i] += pad_start + pad_end; + block_num_elems *= block_shape[i]; + } + // Don't pad the remainder dimensions. + for (int i = 0; i < remainder_shape.size(); ++i) { + padding_config.add_dimensions(); + } + OP_REQUIRES(ctx, block_num_elems > 0, + errors::InvalidArgument( + "The product of the block dimensions must be positive")); + + xla::ComputationDataHandle padded = + b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + + // 2. Reshape `padded` to `reshaped_padded` of shape: + // + // [batch] + + // [padded_shape[1] / block_shape[0], + // block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1], + // block_shape[M-1]] + + // remaining_shape + const int64 batch_size = input_shape[0]; + std::vector<int64> reshaped_padded_shape(input_rank + block_rank); + reshaped_padded_shape[0] = batch_size; + for (int i = 0; i < block_rank; ++i) { + OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0, + errors::InvalidArgument("padded_shape[", 1 + i, + "]=", padded_shape[1 + i], + " is not divisible by block_shape[", i, + "]=", block_shape[i])); + + reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i]; + reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + reshaped_padded_shape.begin() + 1 + 2 * block_rank); + + xla::ComputationDataHandle reshaped_padded = + b->Reshape(padded, reshaped_padded_shape); + + // 3. Permute dimensions of `reshaped_padded` to produce + // `permuted_reshaped_padded` of shape: + // + // block_shape + + // [batch] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + std::vector<int64> permutation(reshaped_padded_shape.size()); + for (int i = 0; i < block_rank; ++i) { + permutation[i] = 1 + 2 * i + 1; + permutation[block_rank + 1 + i] = 1 + 2 * i; + } + permutation[block_rank] = 0; + std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), + 1 + block_rank * 2); + xla::ComputationDataHandle permuted_reshaped_padded = + b->Transpose(reshaped_padded, permutation); + + // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the + // batch dimension, producing an output tensor of shape: + // + // [batch * prod(block_shape)] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + // Determine the length of the prefix of block dims that can be combined + // into the batch dimension due to having no padding and block_shape=1.
+ std::vector<int64> output_shape(input_rank); + output_shape[0] = batch_size * block_num_elems; + for (int i = 0; i < block_rank; ++i) { + output_shape[1 + i] = padded_shape[1 + i] / block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + output_shape.begin() + 1 + block_rank); + + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped_padded, output_shape); + ctx->SetOutput(0, output); +} + +class SpaceToBatchNDOp : public XlaOpKernel { + public: + explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + std::vector<int64> block_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape)); + + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + block_shape, paddings); + } +}; +REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); + +class SpaceToBatchOp : public XlaOpKernel { + public: + explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + {block_size_, block_size_}, paddings); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 18c4c648db1..42bde900422 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -16,9 +16,9 @@ limitations under the License. // XLA-specific Ops for split. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -77,14 +77,14 @@ class SplitOp : public XlaOpKernel { // The vectors we will use to define the slice. The entry for the // split dimensions varies for each output. - std::vector<int64> begin; - std::vector<int64> limits; + std::vector<int64> begin(input_shape.dims(), 0); + std::vector<int64> limits(input_shape.dims()); + std::vector<int64> strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { // Initially set up the limits to be the full size of the input: // the split dimension is filled in below. int64 dim = input_shape.dim_size(i); - begin.push_back(0); - limits.push_back(dim); + limits[i] = dim; } auto input = ctx->Input(1); @@ -94,12 +94,12 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension.
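Taken together, steps 1-4 of SpaceToBatch above pad the spatial dimensions, factor each into (padded/block, block), transpose the block factors into the batch dimension, and flatten. For a single spatial dimension the whole pipeline collapses to output[b][w] = input[w*block + b], which this plain-C++ sketch checks (batch 1, block 2, no padding; illustrative only):

```c++
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> input = {1, 2, 3, 4};  // shape [1, 4]
  const int block = 2, out_width = 4 / block;
  // Output shape [batch * block, out_width]; each output batch b holds the
  // elements whose within-block offset is b.
  for (int b = 0; b < block; ++b) {
    for (int w = 0; w < out_width; ++w)
      std::printf("%d ", input[w * block + b]);
    std::printf("\n");  // rows: "1 3" and "2 4"
  }
}
```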
begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); } } }; -REGISTER_XLA_OP("Split", SplitOp); +REGISTER_XLA_OP(Name("Split"), SplitOp); class SplitVOp : public XlaOpKernel { public: @@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel { std::vector<int64> begin(input_shape.dims(), 0); auto dim_sizes = input_shape.dim_sizes(); std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end()); - + std::vector<int64> strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); int slice_size = split_sizes_vec[i]; @@ -196,13 +196,13 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits)); + ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } }; -REGISTER_XLA_OP("SplitV", SplitVOp); +REGISTER_XLA_OP(Name("SplitV"), SplitVOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 83bf24814f4..9eb68998310 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -76,29 +76,30 @@ class StridedSliceOp : public XlaOpKernel { &dummy, &dummy, &begin, &end, &strides)); gtl::InlinedVector<int64, 4> dimensions_to_reverse; - gtl::InlinedVector<int64, 4> slice_begin, slice_end; + gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides; + for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 when b/30878775 is fixed. - OP_REQUIRES( - ctx, strides[i] == 1 || strides[i] == -1, - errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); if (strides[i] > 0) { slice_begin.push_back(begin[i]); slice_end.push_back(end[i]); + slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed.
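Both Split kernels above now emit each output as a strided Slice: only the begin/limit of the split dimension changes per output, and every stride is 1. A 1-D plain-C++ sketch of that indexing (illustrative only):

```c++
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> input = {1, 2, 3, 4, 5, 6};
  const int num_split = 3;
  const int slice_size = static_cast<int>(input.size()) / num_split;
  // The i-th output is the half-open range [i*slice_size, (i+1)*slice_size).
  for (int i = 0; i < num_split; ++i) {
    const int begin = i * slice_size, limit = (i + 1) * slice_size;
    for (int j = begin; j < limit; ++j) std::printf("%d ", input[j]);
    std::printf("\n");  // outputs: "1 2", "3 4", "5 6"
  }
}
```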
- slice_begin.push_back(end[i] + 1); - slice_end.push_back(begin[i] + 1); + slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } } - xla::ComputationDataHandle slice = - ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end); + + xla::ComputationDataHandle slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } + slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); + slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -109,7 +110,7 @@ class StridedSliceOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP("StridedSlice", StridedSliceOp); +REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { public: @@ -217,7 +218,120 @@ class StridedSliceGradOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP("StridedSliceGrad", StridedSliceGradOp); +REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); + +class StridedSliceAssignOp : public XlaOpKernel { + public: + explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape final_shape; + gtl::InlinedVector<int64, 4> begin; + gtl::InlinedVector<int64, 4> end; + gtl::InlinedVector<int64, 4> strides; + + xla::Literal begin_literal, end_literal, strides_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + + Tensor begin_tensor, end_tensor, strides_tensor; + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); + OP_REQUIRES_OK(ctx, + LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + + DataType lhs_type; + TensorShape lhs_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + + const TensorShape rhs_shape = ctx->InputShape(4); + + TensorShape dummy_processing_shape; + ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); + ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( + &dummy_processing_shape); + bool dummy = false; + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, + ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, + &dummy, &dummy, &begin, &end, &strides)); + + if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { + // DynamicUpdateSlice does not allow 0-element updates. We should probably + // check that rhs_shape can be broadcast to final_shape, but that is + // probably better handled when implementing broadcasting more generally.
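The negative-stride handling above rewrites x[begin:end:stride] with stride < 0 as a positive-stride slice over the reversed axis, remapping begin/end to dim-1-begin and dim-1-end. This plain-C++ sketch checks the equivalence on a 1-D array (illustrative only):

```c++
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> x = {0, 1, 2, 3, 4, 5};
  const int dim = static_cast<int>(x.size());
  const int begin = 4, end = 0, stride = -2;  // x[4:0:-2] -> {4, 2}
  // Reverse the axis, then take a positive-stride slice with the
  // remapped bounds, exactly as the kernel does with Rev + Slice.
  std::vector<int> rev(x.rbegin(), x.rend());
  const int rbegin = dim - begin - 1, rend = dim - end - 1, rstride = -stride;
  for (int i = rbegin; i < rend; i += rstride) std::printf("%d ", rev[i]);
  // Prints: 4 2
}
```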
+ return; + } + + // TODO(aselle): This check is too strong, we only should need + // input_shape to be broadcastable to final_shape + OP_REQUIRES(ctx, final_shape == rhs_shape, + errors::Unimplemented( + "sliced l-value shape ", final_shape.DebugString(), + " does not match r-value shape ", rhs_shape.DebugString(), + ". Automatic broadcasting not yet implemented.")); + + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); + + xla::ComputationDataHandle rhs = ctx->Input(4); + + gtl::InlinedVector<int64, 4> dimensions_to_reverse; + gtl::InlinedVector<int64, 4> slice_begin, slice_dims; + for (int i = 0; i < begin.size(); ++i) { + // TODO(phawkins): implement strides != 1 + OP_REQUIRES( + ctx, strides[i] == 1 || strides[i] == -1, + errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + if (strides[i] > 0) { + slice_begin.push_back(begin[i]); + slice_dims.push_back(end[i] - begin[i]); + } else { + // Negative stride: swap begin and end, add 1 because the interval + // is semi-open, and mark the dimension to be reversed. + slice_begin.push_back(end[i] + 1); + slice_dims.push_back(begin[i] - end[i]); + dimensions_to_reverse.push_back(i); + } + } + + if (!dimensions_to_reverse.empty()) { + rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + } + rhs = ctx->builder()->Reshape(rhs, slice_dims); + + if (lhs_shape.dims() == 0) { + // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix + // and remove this workaround. + lhs = rhs; + } else { + lhs = ctx->builder()->DynamicUpdateSlice( + lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin)); + } + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + } + + private: + int32 begin_mask_, end_mask_; + int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + DataType index_type_; +}; + +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc new file mode 100644 index 00000000000..deee7dd44db --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -0,0 +1,540 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorArray operators.
+ +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// Since the element shape is not always provided to the TensorArrayV3 operator, +// we must support lazy initialization of the TensorArray at the time of the +// first write. +// If a TensorArray `var` has not been initialized, constructs storage for the +// TensorArray with elements of `elem_shape`. For both initialized and +// uninitialized TensorArrays, checks that the tensor has a type compatible with +// 'dtype' and shape compatible with 'elem_shape'. +Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, + XlaVariable* var, DataType dtype, + const TensorShape& elem_shape) { + if (var->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(var->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + TF_RET_CHECK(var->tensor_array_size >= 0) + << var->name << " size " << var->tensor_array_size; + TensorShape ta_shape; + ta_shape.AddDim(var->tensor_array_size); + ta_shape.AppendShape(elem_shape); + + if (var->value.handle() == 0) { + // TensorArray has not been initialized. + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); + var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + } else { + // Checks the elem_shape matches the TensorArray shape. + auto shape_or_status = builder->GetShape(var->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + if (ta_shape != shape) { + return errors::InvalidArgument( + "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", + shape.DebugString()); + } + } + return Status::OK(); +} + +// Pads 'x' with 'count' zero indices. 'x' must have 1 element. +xla::ComputationDataHandle PadIndexWithZeros( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + int count) { + xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0}); + std::vector<xla::ComputationDataHandle> xs(count + 1, zero); + xs[0] = builder->Reshape(x, {1}); + return builder->ConcatInDim(xs, 0); +} + +// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the +// relevant slice of 'operand'.
+xla::ComputationDataHandle DynamicAddSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, + const xla::ComputationDataHandle& update, + const gtl::ArraySlice& update_dims, + const xla::ComputationDataHandle& start_indices) { + xla::ComputationDataHandle current = + builder->DynamicSlice(operand, start_indices, update_dims); + xla::ComputationDataHandle sum = builder->Add(current, update); + return builder->DynamicUpdateSlice(operand, sum, start_indices); +} + +class TensorArrayOp : public XlaOpKernel { + public: + explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + bool dynamic_size; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size)); + OP_REQUIRES( + ctx, !dynamic_size, + errors::Unimplemented( + "TensorArrays with dynamic size are not supported by XLA.")); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + int64 size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + OP_REQUIRES(ctx, size >= 0, + errors::InvalidArgument("TensorArray size must be >= 0")); + + xla::ComputationBuilder* b = ctx->builder(); + b->set_die_immediately_on_error(true); + + // Initializes the TensorArray value if we know the element shape. + // Otherwise, defer initialization to the first write. + xla::ComputationDataHandle value; + if (element_shape_.IsFullyDefined()) { + TensorShape shape; + CHECK(element_shape_.AsTensorShape(&shape)); + TensorShape ta_shape; + ta_shape.AddDim(size); + ta_shape.AppendShape(shape); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); + value = b->Broadcast(zero, ta_shape.dim_sizes()); + } + + XlaContext& xc = XlaContext::Get(ctx); + XlaVariable* var; + string name = strings::StrCat("TensorArray: ", tensor_array_name_); + OP_REQUIRES_OK(ctx, + xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + var->tensor_array_size = size; + ctx->SetVariableOutput(0, var); + ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + } + + private: + PartialTensorShape element_shape_; + DataType dtype_; + string tensor_array_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); + +class TensorArrayWriteOp : public XlaOpKernel { + public: + explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape elem_shape = ctx->InputShape(2); + + // Initializes the TensorArray, if the element shape was not known at + // construction time. + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + + xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle index = ctx->Input(1); + xla::ComputationDataHandle value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
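The start-index vector named in the comment above is built the same way by every element access in this file. A host-side sketch of what PadIndexWithZeros materializes as a constant concat in XLA (hypothetical helper name; int64 index type assumed for illustration):

```c++
#include <cstdint>
#include <vector>

// Builds the start-index vector [index, 0, 0, ..., 0] that positions a
// [1, elem...] window at element `index` along the leading dimension.
std::vector<int64_t> MakeStartIndices(int64_t index, int trailing_zeros) {
  std::vector<int64_t> start(trailing_zeros + 1, 0);
  start[0] = index;
  return start;
}
```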
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = b->Reshape(value, slice_shape.dim_sizes()); + + xla::ComputationDataHandle written = + DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp); + +class TensorArrayReadOp : public XlaOpKernel { + public: + explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(ta_type), + " but Op requested dtype ", DataTypeString(dtype_), ".")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + + auto slice_shape = ta_shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::ComputationDataHandle read = + b->DynamicSlice(ta, start_indices, slice_shape); + + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + ctx->SetOutput(0, b->Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp); + +class TensorArrayGatherOp : public XlaOpKernel { + public: + explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument("TensorArray type mismatch")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, indices_shape.dims() >= 1, + errors::InvalidArgument("indices must be rank 1")); + const int num_indices = indices_shape.dim_size(0); + auto indices = ctx->Input(1); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + + // For each index in `indices`, add the corresponding slice to `slices`. + std::vector slices(num_indices); + for (int i = 0; i < num_indices; ++i) { + // Slices the i-th index out of `indices`, and pads it with zeros in the + // minor dimensions to form an index into the TensorArray storage. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
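Note that the write path above accumulates into the array rather than overwriting it: DynamicAddSlice composes DynamicSlice, Add, and DynamicUpdateSlice, so two writes to the same index sum. A rank-1 host-side analogue, with hypothetical names, for intuition only:

```c++
#include <cstddef>
#include <vector>

// Rank-1 analogue of DynamicAddSlice: read the current window, add the
// update, and write the sum back at the same offset.
void DynamicAddSlice1D(std::vector<float>* operand,
                       const std::vector<float>& update, std::size_t start) {
  for (std::size_t i = 0; i < update.size(); ++i) {
    (*operand)[start + i] += update[i];
  }
}
```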
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + + auto slice_shape = ta_shape.dim_sizes(); + slice_shape[0] = 1LL; + + slices[i] = b->DynamicSlice(ta, start_indices, slice_shape); + } + + xla::ComputationDataHandle gather; + if (slices.empty()) { + auto shape = ta_shape.dim_sizes(); + shape[0] = 0; + gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape); + } else { + gather = b->ConcatInDim(slices, 0); + } + ctx->SetOutput(0, gather); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp); + +class TensorArrayScatterOp : public XlaOpKernel { + public: + explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + const TensorShape value_shape = ctx->InputShape(2); + + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + TensorShape elem_shape = value_shape; + elem_shape.RemoveDim(0); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, indices_shape.dims() >= 1, + errors::InvalidArgument("indices must be rank 1")); + const int num_indices = indices_shape.dim_size(0); + const xla::ComputationDataHandle indices = ctx->Input(1); + + xla::ComputationDataHandle ta = var->value; + const xla::ComputationDataHandle value = ctx->Input(2); + + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
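Both the gather loop above and the scatter loop here reduce to per-index row copies over the [size, elem...] storage. A rank-2 host-side analogue of the scatter, keeping the accumulate semantics of DynamicAddSlice (hypothetical types, not the patch's own):

```c++
#include <cstddef>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Host-side analogue of the scatter loop: add row i of `value` into
// row indices[i] of the TensorArray storage (accumulate, not replace).
void ScatterAdd(Matrix* ta, const std::vector<int64_t>& indices,
                const Matrix& value) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    std::vector<float>& dst = (*ta)[indices[i]];
    for (std::size_t j = 0; j < dst.size(); ++j) {
      dst[j] += value[i][j];
    }
  }
}
```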
+ auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp); + +class TensorArrayConcatOp : public XlaOpKernel { + public: + explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_type == dtype_, + errors::InvalidArgument("TensorArray type mismatch")); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + xla::ComputationBuilder* b = ctx->builder(); + + xla::ComputationDataHandle ta; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + + auto ta_dims = ta_shape.dim_sizes(); + std::vector shape(ta_dims.begin() + 1, ta_dims.end()); + shape[0] *= ta_shape.dim_size(0); + ctx->SetOutput(0, b->Reshape(ta, shape)); + + Tensor lengths(DT_INT64, {ta_dims[0]}); + auto lengths_vec = lengths.vec(); + for (int i = 0; i < ta_dims[0]; ++i) { + lengths_vec(i) = ta_dims[1]; + } + ctx->SetConstantOutput(1, lengths); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp); + +class TensorArraySplitOp : public XlaOpKernel { + public: + explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + std::vector lengths; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths)); + + int64 length = 0; + if (!lengths.empty()) { + length = lengths[0]; + for (int i = 1; i < lengths.size(); ++i) { + OP_REQUIRES(ctx, lengths[i] == length, + errors::InvalidArgument("lengths must be equal: ", length, + " vs. ", lengths[i])); + } + } + + TensorShape value_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, value_shape.dims() >= 1, + errors::InvalidArgument("value must have rank >= 1, got ", + value_shape.DebugString())); + TensorShape elem_shape = value_shape; + elem_shape.set_dim(0, length); + + xla::ComputationBuilder* b = ctx->builder(); + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + xla::ComputationDataHandle ta = var->value; + + TensorShape ta_shape; + ta_shape.AddDim(var->tensor_array_size); + ta_shape.AppendShape(elem_shape); + + OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", var->tensor_array_size, ")")); + + const xla::ComputationDataHandle value = ctx->Input(1); + + OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), + errors::InvalidArgument("mismatched element count ", + value_shape.DebugString(), " vs. 
", + ta_shape.DebugString())); + + ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + + ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); +}; + +REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); + +class TensorArraySizeOp : public XlaOpKernel { + public: + explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + Tensor size_tensor(DT_INT32, {}); + size_tensor.scalar()() = static_cast(var->tensor_array_size); + ctx->SetConstantOutput(0, size_tensor); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp); +}; + +REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp); + +class TensorArrayGradOp : public XlaOpKernel { + public: + explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + XlaVariable* var; + OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + + DataType ta_type; + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); + OP_REQUIRES(ctx, ta_shape.dims() >= 1, + errors::InvalidArgument("TensorArray rank must be >= 1")); + + // Finds or looks up the corresponding gradient TensorArray, which stores + // gradients computed during backpropagation. + XlaVariable*& gradient = var->tensor_array_gradient[source_]; + if (!gradient) { + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle value = + b->Broadcast(zero, ta_shape.dim_sizes()); + + XlaContext& xc = XlaContext::Get(ctx); + string name = strings::StrCat("TensorArrayGrad: ", var->name); + OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, + value, &gradient)); + gradient->tensor_array_size = var->tensor_array_size; + } + + ctx->SetVariableOutput(0, gradient); + ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + } + + private: + string source_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 45ac5e12c74..4cc2eb8f877 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP("Tile", TileOp); +REGISTER_XLA_OP(Name("Tile"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc new file mode 100644 index 00000000000..e9ac1ee91b8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -0,0 +1,475 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class ResourceApplyGradientDescent : public XlaOpKernel { + public: + explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle handle; + xla::ComputationBuilder* b = ctx->builder(); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"), + ResourceApplyGradientDescent); + +class ResourceApplyMomentum : public XlaOpKernel { + public: + explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + DataType type = ctx->input_type(2); + + DataType var_type, accum_type; + TensorShape var_shape, accum_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); + + OP_REQUIRES( + ctx, type == var_type && type == accum_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyMomentum must match: ", + DataTypeString(type), " vs. 
", DataTypeString(var_type), " and ", + DataTypeString(accum_type))); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); + + xla::ComputationDataHandle lr = ctx->Input(2); + xla::ComputationDataHandle grad = ctx->Input(3); + xla::ComputationDataHandle momentum = ctx->Input(4); + + accum = b->Add(b->Mul(accum, momentum), grad); + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. + var = b->Sub( + var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr))); + } else { + var = b->Sub(var, b->Mul(accum, lr)); + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum); + +class ResourceApplyAdagrad : public XlaOpKernel { + public: + explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + DataType type = ctx->input_type(2); + + DataType var_type, accum_type; + TensorShape var_shape, accum_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); + + OP_REQUIRES( + ctx, type == var_type && type == accum_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyAdagrad must match: ", + DataTypeString(type), " vs. 
", DataTypeString(var_type), " and ", + DataTypeString(accum_type))); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); + xla::ComputationDataHandle lr = ctx->Input(2); + xla::ComputationDataHandle grad = ctx->Input(3); + + accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); + var = b->Sub( + var, b->Mul(b->Mul(grad, lr), + b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad); + +class ResourceApplyAdam : public XlaOpKernel { + public: + explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType var_type, m_type, v_type; + TensorShape var_shape, m_shape, v_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); + + OP_REQUIRES( + ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyRMSProp must match: ", + DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", + DataTypeString(m_type), " vs. 
", DataTypeString(v_type))); + + TensorShape beta1_power_shape = ctx->InputShape(3); + TensorShape beta2_power_shape = ctx->InputShape(4); + TensorShape lr_shape = ctx->InputShape(5); + TensorShape beta1_shape = ctx->InputShape(6); + TensorShape beta2_shape = ctx->InputShape(7); + TensorShape epsilon_shape = ctx->InputShape(8); + TensorShape grad_shape = ctx->InputShape(9); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape), + errors::InvalidArgument("beta2_power is not a scalar: ", + beta2_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar : ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape), + errors::InvalidArgument("var and v do not have the same shape", + var_shape.DebugString(), " ", + v_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); + xla::ComputationDataHandle beta1_power = ctx->Input(3); + xla::ComputationDataHandle beta2_power = ctx->Input(4); + xla::ComputationDataHandle lr = ctx->Input(5); + xla::ComputationDataHandle beta1 = ctx->Input(6); + xla::ComputationDataHandle beta2 = ctx->Input(7); + xla::ComputationDataHandle epsilon = ctx->Input(8); + xla::ComputationDataHandle grad = ctx->Input(9); + + // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) + // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t + // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t + // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + xla::ComputationDataHandle alpha = + b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), + b->Sub(one, beta1_power)); + m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); + v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); + var = + b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam); + +class 
ResourceApplyRMSProp : public XlaOpKernel { + public: + explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + DataType type = ctx->input_type(3); + + DataType var_type, ms_type, mom_type; + TensorShape var_shape, ms_shape, mom_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape)); + + OP_REQUIRES( + ctx, type == var_type && type == ms_type && type == mom_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyRMSProp must match: ", + DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ", + DataTypeString(ms_type), " vs. ", DataTypeString(mom_type))); + + TensorShape lr_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + TensorShape rho_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), + errors::InvalidArgument("rho is not a scalar: ", + rho_shape.DebugString())); + TensorShape momentum_shape = ctx->InputShape(5); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + TensorShape epsilon_shape = ctx->InputShape(6); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(7); + + // var should be the same shape as mom and ms. + OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape), + errors::InvalidArgument("var and ms do not have the same shape", + var_shape.DebugString(), " ", + ms_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(mom_shape), + errors::InvalidArgument( + "var and mom do not have the same shape", + var_shape.DebugString(), " ", mom_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::ComputationDataHandle var, ms, mom; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom)); + xla::ComputationDataHandle lr = ctx->Input(3); + xla::ComputationDataHandle rho = ctx->Input(4); + xla::ComputationDataHandle momentum = ctx->Input(5); + xla::ComputationDataHandle epsilon = ctx->Input(6); + xla::ComputationDataHandle grad = ctx->Input(7); + + // ms <- rho * ms_{t-1} + (1-rho) * grad * grad + // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) + // var <- var - mom + // + // We use an alternate formulation of the ms equation: + // + // ms <- ms + (grad**2 - ms) * (1 - rho) + // + // Which expands to: + // + // ms <- ms + grad**2 - rho * grad ** 2 - ms + ms * rho + // + // Which simplifies to: + // + // ms <- grad**2 (1 - rho) + ms * rho + // + // Which is the equation listed above. 
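A quick host-side check of the rearrangement derived in the comment above, in plain C++ and not part of the patch; both forms of the ms update agree up to rounding error:

```c++
#include <cassert>
#include <cmath>

int main() {
  const double ms = 0.3, grad = 1.7, rho = 0.9;
  // Canonical form: ms <- rho * ms + (1 - rho) * grad^2.
  const double canonical = rho * ms + (1.0 - rho) * grad * grad;
  // Rearranged form used below: ms <- ms + (grad^2 - ms) * (1 - rho).
  const double rearranged = ms + (grad * grad - ms) * (1.0 - rho);
  assert(std::abs(canonical - rearranged) < 1e-12);
  return 0;
}
```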
+ xla::ComputationDataHandle new_ms = b->Add( + ms, + b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), + b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); + xla::ComputationDataHandle new_mom = + b->Add(b->Mul(mom, momentum), + b->Div(b->Mul(grad, lr), + b->Pow(b->Add(new_ms, epsilon), + XlaHelpers::FloatLiteral(b, type, 0.5)))); + xla::ComputationDataHandle new_var = b->Sub(var, new_mom); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom)); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp); + +class ResourceApplyFtrl : public XlaOpKernel { + public: + explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + DataType var_type, accum_type, linear_type; + TensorShape var_shape, accum_shape, linear_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); + OP_REQUIRES_OK( + ctx, ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); + + OP_REQUIRES( + ctx, + dtype_ == var_type && dtype_ == accum_type && dtype_ == linear_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyFtrl must match: ", + DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " and ", + DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape), + errors::InvalidArgument( + "var and linear do not have the same shape", + var_shape.DebugString(), " ", linear_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape l1_shape = ctx->InputShape(5); + TensorShape l2_shape = ctx->InputShape(6); + TensorShape lr_power_shape = ctx->InputShape(7); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape), + errors::InvalidArgument("lr_power is not a scalar: ", + lr_power_shape.DebugString())); + + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); + xla::ComputationDataHandle grad = ctx->Input(3); + xla::ComputationDataHandle lr = ctx->Input(4); + xla::ComputationDataHandle l1 = ctx->Input(5); + xla::ComputationDataHandle l2 = ctx->Input(6); + xla::ComputationDataHandle lr_power = ctx->Input(7); + + // new_accum = accum + grad * grad + // 
linear += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var + // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 + // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 + // accum = new_accum + + xla::ComputationDataHandle zero_broadcast = b->Broadcast( + XlaHelpers::FloatLiteral(b, dtype_, 0.0), var_shape.dim_sizes()); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + xla::ComputationDataHandle new_accum = b->Add(accum, b->Pow(grad, two)); + xla::ComputationDataHandle new_accum_lr_pow = + b->Pow(new_accum, b->Neg(lr_power)); + xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + linear = b->Add( + linear, + b->Sub(grad, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), + var))); + xla::ComputationDataHandle quadratic = + b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::ComputationDataHandle pre_shrink = + b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic); + var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast); + accum = new_accum; + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, linear)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 2840abc8782..2fc5d40d105 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/core/kernels/transpose_op.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -87,7 +87,7 @@ class TransposeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("Transpose", TransposeOp); +REGISTER_XLA_OP(Name("Transpose"), TransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. // @@ -128,7 +128,8 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP("InvertPermutation", InvertPermutationOp); +REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), + InvertPermutationOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index c3ba1a7a8b0..abe4949f5db 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -16,8 +16,8 @@ limitations under the License. 
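The final Select in ResourceApplyFtrl above is a soft-threshold: any weight whose |linear| stays at or below l1 is pinned to exactly zero, which is how FTRL yields sparse models. A scalar sketch with a hypothetical helper name:

```c++
#include <cmath>

// Scalar form of the Select step above: below the l1 threshold the
// weight is exactly zero; otherwise it is the shrunk value
// (sign(linear) * l1 - linear) / quadratic.
double FtrlWeight(double linear, double l1, double quadratic) {
  if (std::abs(linear) <= l1) return 0.0;
  const double sign = linear > 0.0 ? 1.0 : -1.0;
  return (sign * l1 - linear) / quadratic;
}
```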
// Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -28,10 +28,10 @@ namespace { // A subclass of a TlaUnaryOp must build the lambda computation that // describes the scalar->scalar function to apply to each element of // the input. -#define XLAJIT_MAKE_UNARY(Name, COMPUTATION) \ - class Name##Op : public XlaOpKernel { \ +#define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ + class NAME##Op : public XlaOpKernel { \ public: \ - explicit Name##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ xla::ComputationBuilder* b = ctx->builder(); \ xla::ComputationDataHandle x = ctx->Input(0); \ @@ -39,7 +39,7 @@ namespace { ctx->SetOutput(0, y); \ } \ }; \ - REGISTER_XLA_OP(#Name, Name##Op); + REGISTER_XLA_OP(Name(#NAME), NAME##Op); // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); @@ -58,6 +58,27 @@ XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x)); XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); + +// Implements Banker's rounding: numbers that are equidistant between two +// integers are rounded towards even. +static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, + DataType dtype, + const xla::ComputationDataHandle& x) { + auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); + auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + + auto round_val = b->Floor(x); + auto fraction = b->Sub(x, round_val); + auto nearest_even_int = + b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); + auto is_odd = b->Eq(nearest_even_int, one); + return b->Select(b->LogicalOr(b->Gt(fraction, half), + b->LogicalAnd(b->Eq(fraction, half), is_odd)), + b->Add(round_val, one), round_val); +} +XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); + XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); XLAJIT_MAKE_UNARY(Sigmoid, diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index c5b2bdaf2dc..f87586ba578 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -19,9 +19,9 @@ limitations under the License. 
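The Round lowering above implements round-half-to-even: fractions above one half round up, and exact halves go to the even neighbour, so 2.5 rounds to 2 while 3.5 rounds to 4. A scalar restatement in plain C++ for checking that tie behaviour (not part of the patch):

```c++
#include <cmath>

double RoundHalfToEven(double x) {
  const double floor_val = std::floor(x);
  const double fraction = x - floor_val;
  // floor(x) - 2 * floor(x / 2) is 1 exactly when floor(x) is odd,
  // mirroring the nearest_even_int test in the XLA lowering above.
  const bool floor_is_odd = (floor_val - 2.0 * std::floor(x / 2.0)) == 1.0;
  if (fraction > 0.5 || (fraction == 0.5 && floor_is_odd)) {
    return floor_val + 1.0;
  }
  return floor_val;
}
```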
#include #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -66,6 +66,7 @@ class UnpackOp : public XlaOpKernel { std::vector start_indices(input_shape.dims(), 0); std::vector limit_indices(input_shape.dims()); + std::vector strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { limit_indices[i] = input_shape.dim_size(i); } @@ -73,7 +74,8 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices); + auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, + strides); // Reshape to drop the 'axis' dimension. auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); @@ -84,7 +86,7 @@ class UnpackOp : public XlaOpKernel { int axis_; }; -REGISTER_XLA_OP("Unpack", UnpackOp); +REGISTER_XLA_OP(Name("Unpack"), UnpackOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc new file mode 100644 index 00000000000..1b04b8b802c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class VarIsInitializedOp : public XlaOpKernel { + public: + explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle handle; + bool initialized = ctx->ReadVariableInput(0, &handle).ok(); + ctx->SetOutput(0, ctx->builder()->ConstantR0(initialized)); + } +}; +REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); + +class ReadVariableOp : public XlaOpKernel { + public: + explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle handle; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + ctx->SetOutput(0, handle); + } +}; +REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("_UnsafeReadVariable"), ReadVariableOp); + +class AssignVariableOp : public XlaOpKernel { + public: + explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, + ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); + } +}; +REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); + +class AssignAddVariableOp : public XlaOpKernel { + public: + explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle handle; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + handle = ctx->builder()->Add(handle, ctx->Input(1)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + } +}; +REGISTER_XLA_OP( + Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes), + AssignAddVariableOp); + +class AssignSubVariableOp : public XlaOpKernel { + public: + explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle handle; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + handle = ctx->builder()->Sub(handle, ctx->Input(1)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + } +}; +REGISTER_XLA_OP( + Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes), + AssignSubVariableOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 3e509375efb..fe08e83c239 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,14 +18,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -// Copies 'host_tensor' to an XLA Literal. 
Fails if the host_tensor has zero -// elements or is of an unsupported type. +// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an +// unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // Copies 'literal' to 'host_tensor', which is allocated of type . diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc deleted file mode 100644 index e32070efa32..00000000000 --- a/tensorflow/compiler/tf2xla/op_registrations.cc +++ /dev/null @@ -1,510 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Kernel registrations for XLA JIT devices. - -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace { - -// CPU JIT device registrations. - -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("_Arg").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ArrayToList")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ListToArray")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("_Retval").TypeConstraint("T", kCpuAllTypes)); - -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Abs").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Add").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("AddN").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("All")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Any")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("AvgPool").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("AvgPoolGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("BatchMatMul").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("BiasAdd").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("BiasAddV1").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("BiasAddGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("BroadcastGradientArgs")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Cast") - .TypeConstraint("SrcT", kCpuAllTypes) - .TypeConstraint("DstT", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Ceil").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Concat").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("ConcatV2") - .TypeConstraint("T", kCpuAllTypes) - .TypeConstraint("Tidx", DT_INT32)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ConcatOffset")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Conv2D").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - 
Name("Conv2DBackpropFilter").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - Name("Conv2DBackpropInput").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - Name("DepthwiseConv2dNative").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Diag").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("DiagPart").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Div").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("DynamicStitch").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Equal").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Exp").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("ExpandDims").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Fill").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Floor").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("FloorDiv").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("FloorMod").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Greater").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("GreaterEqual").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Inv").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Reciprocal").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("InvertPermutation").TypeConstraint("T", DT_INT32)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("L2Loss").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Less").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("LessEqual").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("LinSpace").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Log").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Log1p").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalAnd")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalNot")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalOr")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("LogSoftmax").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("LRN").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("LRNGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Maximum").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("MatMul").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("MatrixDiag").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("MatrixDiagPart").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Max").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("MaxPool").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - 
Name("MaxPoolGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Mean").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Min").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Minimum").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Mod").TypeConstraint("T", kCpuIntTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Mul").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Neg").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("NotEqual").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Pack").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Pad").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Pow").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("PreventGradient").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Prod").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Range").TypeConstraint("Tidx", kCpuNumericTypes)); -// TODO(b/34339814): implement inverse erf for double types and update the -// type constraint. -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - Name("RandomStandardNormal").TypeConstraint("dtype", DT_FLOAT)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniform")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniformInt")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Rank")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("RealDiv").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Relu").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Relu6").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("ReluGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Relu6Grad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Reshape").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Rsqrt").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("RsqrtGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Select").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Shape")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ShapeN")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Sigmoid").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("SigmoidGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Sign").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Size")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Slice").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Softmax").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Softplus").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("SoftplusGrad").TypeConstraint("T", kCpuFloatTypes)); 
-REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("SparseMatMul") - .TypeConstraint("Ta", kCpuFloatTypes) - .TypeConstraint("Tb", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Split").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("SplitV").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Square").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL( - DEVICE_CPU_XLA_JIT, - Name("SquaredDifference").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Squeeze").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Sqrt").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("StopGradient").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("StridedSlice").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("StridedSliceGrad").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Sub").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Sum").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("SymbolicGradient")); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Tanh").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("TanhGrad").TypeConstraint("T", kCpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Tile").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Transpose").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("TruncateDiv").TypeConstraint("T", kCpuIntTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("TruncateMod").TypeConstraint("T", kCpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Unpack").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, - Name("ZerosLike").TypeConstraint("T", kCpuNumericTypes)); - -REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT, - Name("Const").TypeConstraint("dtype", - kCpuAllTypes)); -REGISTER_XLA_JIT_ONLY_KERNEL( - DEVICE_CPU_XLA_JIT, Name("Identity").TypeConstraint("T", kCpuAllTypes)); -REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT, Name("NoOp")); - -// GPU JIT device registrations - -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("_Arg").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ArrayToList")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ListToArray")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("_Retval").TypeConstraint("T", kGpuAllTypes)); - -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Abs").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Add").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("AddN").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("All")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Any")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("AvgPool").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("AvgPoolGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("BatchMatMul").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("BiasAdd").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - 
Name("BiasAddV1").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("BiasAddGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("BroadcastGradientArgs")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Cast") - .TypeConstraint("SrcT", kGpuAllTypes) - .TypeConstraint("DstT", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Ceil").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Concat").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("ConcatV2").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ConcatOffset")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Conv2D").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("Conv2DBackpropFilter").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("Conv2DBackpropInput").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("DepthwiseConv2dNative").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Diag").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("DiagPart").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Div").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("DynamicStitch").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Equal").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Exp").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("ExpandDims").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Fill").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Floor").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("FloorDiv").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("FloorMod").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Greater").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("GreaterEqual").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Inv").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Reciprocal").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("InvertPermutation").TypeConstraint("T", DT_INT32)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("L2Loss").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Less").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("LessEqual").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("LinSpace").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Log").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Log1p").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalAnd")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalNot")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalOr")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("LogSoftmax").TypeConstraint("T", 
kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("LRN").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("LRNGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Maximum").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("MatMul").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("MatrixDiag").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("MatrixDiagPart").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Max").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("MaxPool").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("MaxPoolGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Mean").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Min").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Minimum").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Mod").TypeConstraint("T", kGpuIntTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Mul").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Neg").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("NotEqual").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Pack").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Pad").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Pow").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("PreventGradient").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Prod").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Range").TypeConstraint("Tidx", kGpuNumericTypes)); -// TODO(b/31361304): disabled because of XLA bugs. 
-// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomStandardNormal")); -// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniform")); -// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniformInt")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Rank")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("RealDiv").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Relu").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Relu6").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("ReluGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Relu6Grad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Reshape").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Rsqrt").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("RsqrtGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Select").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Shape")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ShapeN")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Sigmoid").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("SigmoidGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Sign").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Size")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Slice").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Softmax").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Softplus").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("SoftplusGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("SparseMatMul") - .TypeConstraint("Ta", kGpuFloatTypes) - .TypeConstraint("Tb", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Split").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("SplitV").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Square").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL( - DEVICE_GPU_XLA_JIT, - Name("SquaredDifference").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Squeeze").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Sqrt").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("StopGradient").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("StridedSlice").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("StridedSliceGrad").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Sub").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Sum").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("SymbolicGradient")); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Tanh").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - 
Name("TanhGrad").TypeConstraint("T", kGpuFloatTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Tile").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Transpose").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Unpack").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, - Name("ZerosLike").TypeConstraint("T", kGpuNumericTypes)); - -REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT, - Name("Const").TypeConstraint("dtype", - kGpuAllTypes)); -REGISTER_XLA_JIT_ONLY_KERNEL( - DEVICE_GPU_XLA_JIT, Name("Identity").TypeConstraint("T", kGpuAllTypes)); -REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT, Name("NoOp")); - -} // anonymous namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc index ce25d631271..2b0834fe7b6 100644 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ b/tensorflow/compiler/tf2xla/str_util.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace str_util { -void ReplaceAll(string* text, StringPiece from, StringPiece to) { +static void ReplaceAll(string* text, StringPiece from, StringPiece to) { size_t pos = 0; while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { text->replace(pos, from.size(), to.data(), to.size()); diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h index 4920b1a4d48..51f25009d70 100644 --- a/tensorflow/compiler/tf2xla/str_util.h +++ b/tensorflow/compiler/tf2xla/str_util.h @@ -29,10 +29,6 @@ limitations under the License. namespace tensorflow { namespace str_util { -// Replace all non-overlapping occurrences of from with to in-place in text. If -// from is empty, it matches at the beginning of the text and after every byte. -void ReplaceAll(string* text, StringPiece from, StringPiece to); - // Replace all non-overlapping occurrences of the given (from,to) pairs in-place // in text. If from is empty, it matches at the beginning of the text and after // every byte. Each (from,to) replacement pair is processed in the order it is diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc index f992007a345..8817f6902a8 100644 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ b/tensorflow/compiler/tf2xla/str_util_test.cc @@ -25,36 +25,6 @@ limitations under the License. 
namespace tensorflow { namespace str_util { -class ReplaceAllTest : public ::testing::Test { - protected: - void ExpectReplaceAll(string text, StringPiece from, StringPiece to, - StringPiece want) { - ReplaceAll(&text, from, to); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllTest, Simple) { - ExpectReplaceAll("", "", "", ""); - ExpectReplaceAll("", "", "X", "X"); - ExpectReplaceAll("", "", "XYZ", "XYZ"); - ExpectReplaceAll("banana", "", "", "banana"); - ExpectReplaceAll("banana", "", "_", "_b_a_n_a_n_a_"); - ExpectReplaceAll("banana", "", "__", "__b__a__n__a__n__a__"); - ExpectReplaceAll("banana", "a", "a", "banana"); - ExpectReplaceAll("banana", "a", "", "bnn"); - ExpectReplaceAll("banana", "a", "X", "bXnXnX"); - ExpectReplaceAll("banana", "a", "XX", "bXXnXXnXX"); - ExpectReplaceAll("banana", "an", "an", "banana"); - ExpectReplaceAll("banana", "an", "", "ba"); - ExpectReplaceAll("banana", "an", "X", "bXXa"); - ExpectReplaceAll("banana", "an", "XY", "bXYXYa"); - ExpectReplaceAll("banana", "an", "XYZ", "bXYZXYZa"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "X", "foo X baz X"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "ABCDEFGHIJKLMNOP", - "foo ABCDEFGHIJKLMNOP baz ABCDEFGHIJKLMNOP"); -} - class ReplaceAllPairsTest : public ::testing::Test { protected: void ExpectReplaceAllPairs( diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index ad3c9217440..1d0098591e3 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -18,20 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { -const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT"; -const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; - // The XlaCompilationAllocator doesn't actually back any Tensors with storage // buffers of values: instead for each Tensor it stores a // XlaExpression which corresponds to the XLA computation @@ -41,13 +34,12 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationAllocator() {} ~XlaCompilationAllocator() override {} - string Name() override { return "tla_jit"; } + string Name() override { return "xla_compilation"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { - // Regardless of the size requested, always allocate a - // XlaExpression. Respect the aligment request because there is - // alignment checking even for Tensors whose data is never - // accessed. + // Regardless of the size requested, always allocates an XlaExpression. + // Respects the alignment request because there is alignment checking even + // for Tensors whose data is never accessed. 
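The rewritten comment above describes the core trick of the compilation device: the allocator ignores the requested byte count and always hands back a single fixed-size symbolic-expression object, so no real tensor storage is ever created during symbolic execution. A minimal standalone analog, not TensorFlow code, using posix_memalign in place of port::AlignedMalloc:

```c++
// Standalone analog of the allocation trick described above: every "tensor
// buffer" allocation actually returns one fixed-size symbolic-expression
// object, so no real data storage is ever created.
#include <cstdlib>
#include <new>

struct Expr {       // Stand-in for XlaExpression.
  int handle = 0;   // Would hold an xla::ComputationDataHandle.
};

void* AllocateRaw(std::size_t alignment, std::size_t /*num_bytes*/) {
  // Ignore the requested size; honor alignment because callers may check it
  // even for tensors whose data is never accessed.
  void* p = nullptr;
  if (posix_memalign(&p, alignment, sizeof(Expr)) != 0) return nullptr;
  return new (p) Expr();  // Placement-new the expression into the raw buffer.
}

void DeallocateRaw(void* p) {
  static_cast<Expr*>(p)->~Expr();
  std::free(p);
}

int main() {
  void* p = AllocateRaw(64, 1 << 20);  // Caller asks for 1 MiB, gets an Expr.
  DeallocateRaw(p);
}
```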
void* p = port::AlignedMalloc(sizeof(XlaExpression), alignment); XlaExpression* expression = reinterpret_cast<XlaExpression*>(p); new (expression) XlaExpression(); @@ -80,11 +72,11 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, DeviceType type) - : LocalDevice(options, - Device::BuildDeviceAttributes( - "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA JIT device ", type.type())), - cpu_allocator()), + : LocalDevice( + options, + Device::BuildDeviceAttributes( + "", type, Bytes(256 << 20), DeviceLocality(), + strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} @@ -93,112 +85,38 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) { return allocator_.get(); } +void XlaCompilationDevice::Compute(OpKernel* op_kernel, + OpKernelContext* context) { + VLOG(1) << "XlaCompilationDevice::Compute " + << SummarizeNodeDef(op_kernel->def()); + auto* b = XlaContext::Get(context).builder(); + xla::OpMetadata metadata; + metadata.set_op_type(op_kernel->type_string()); + metadata.set_op_name(op_kernel->name()); + b->SetOpMetadata(metadata); + op_kernel->Compute(context); + b->ClearOpMetadata(); + VLOG(2) << "Done"; +} + Status XlaCompilationDevice::Sync() { return Status::OK(); } Status XlaCompilationDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { return errors::InvalidArgument( - "Tla JIT Device should not parse tensor from proto"); + "XLACompilationDevice::MakeTensorFromProto should not be called"); } -// Is platform 'id' supported by XLA? -static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { - auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); - if (!platform.ok()) return false; - return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +XlaExpression::XlaExpression() = default; + +void XlaExpression::set_handle(const xla::ComputationDataHandle& h) { + handle_ = h; } -XlaOpRegistry::XlaOpRegistry() = default; -XlaOpRegistry::~XlaOpRegistry() = default; - -/* static */ void XlaOpRegistry::RegisterJitDevice( - const string& device_name, const string& jit_device_name, - bool requires_jit) { - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - auto result = registry.jit_devices_.emplace( - device_name, std::make_pair(jit_device_name, requires_jit)); - CHECK(result.second || result.first->second.first == jit_device_name); -} - -/* static */ bool XlaOpRegistry::GetJitDevice(const string& device_name, - const string** jit_device_name, - bool* requires_jit) { - XlaOpRegistry& registry = Instance(); - - // Lazily register the CPU and GPU JIT devices the first time GetJitDevice is - // called. 
- static void* registration = [&registry]() { - mutex_lock lock(registry.mutex_); - if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) { - registry.jit_devices_[DEVICE_CPU] = {DEVICE_CPU_XLA_JIT, false}; - } - if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) { - registry.jit_devices_[DEVICE_GPU] = {DEVICE_GPU_XLA_JIT, false}; - } - return nullptr; - }(); - (void)registration; - - mutex_lock lock(registry.mutex_); - auto it = registry.jit_devices_.find(device_name); - if (it == registry.jit_devices_.end()) return false; - if (jit_device_name) *jit_device_name = &it->second.first; - if (requires_jit) *requires_jit = it->second.second; - return true; -} - -void XlaOpRegistry::RegisterJitKernels() { - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - - if (registry.jit_kernels_registered_) return; - registry.jit_kernels_registered_ = true; - - for (const auto& entry : registry.kernels_) { - for (const XlaKernel& k : entry.second) { - auto it = registry.ops_.find(k.kernel_def->op()); - CHECK(it != registry.ops_.end()) << "Missing XLA op registration for op " - << k.kernel_def->op(); - registry.kernel_registrars_.emplace_back( - new kernel_factory::OpKernelRegistrar(new KernelDef(*k.kernel_def), - "XlaJitOp", it->second)); - } - } -} - -std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels( - const string& jit_device_type) { - std::vector<const KernelDef*> kernels; - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - for (const XlaKernel& k : registry.kernels_.at(jit_device_type)) { - if (!k.jit_only) { - kernels.push_back(k.kernel_def.get()); - } - } - return kernels; -} - -XlaOpRegistry& XlaOpRegistry::Instance() { - static XlaOpRegistry* r = new XlaOpRegistry; - return *r; -} - -XlaOpRegistrar::XlaOpRegistrar(StringPiece name, - XlaOpRegistry::Factory factory) { - XlaOpRegistry& registry = XlaOpRegistry::Instance(); - mutex_lock lock(registry.mutex_); - CHECK(registry.ops_.emplace(name.ToString(), factory).second) - << "Duplicate XLA op registration " << name; -} - -XlaKernelRegistrar::XlaKernelRegistrar(bool jit_only, const KernelDef* def) { - XlaOpRegistry& registry = XlaOpRegistry::Instance(); - mutex_lock lock(registry.mutex_); - registry.kernels_[def->device_type()].push_back(XlaOpRegistry::XlaKernel{ - jit_only, std::unique_ptr<const KernelDef>(def)}); +void XlaExpression::set_constant_value(Tensor value) { + has_constant_value_ = true; + constant_value_ = std::move(value); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index f4b95b874b6..75630bee396 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -16,44 +16,21 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ -#include #include -#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { -// Names of the XLA JIT devices.
These are not user-visible, and are used -// internally by the JIT to perform symbolic execution of a Tensorflow graph. - -extern const char* const DEVICE_CPU_XLA_JIT; // "CPU_XLA_JIT" -extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" - -constexpr std::array<DataType, 5> kCpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; -constexpr std::array<DataType, 2> kCpuIntTypes = {{DT_INT32, DT_INT64}}; -constexpr std::array<DataType, 2> kCpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}}; -constexpr std::array<DataType, 4> kCpuNumericTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}}; - -constexpr std::array<DataType, 5> kGpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; -constexpr std::array<DataType, 2> kGpuIntTypes = {{DT_INT32, DT_INT64}}; -constexpr std::array<DataType, 2> kGpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}}; -constexpr std::array<DataType, 4> kGpuNumericTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}}; - -// Class is declared and defined in tla_jit_device.cc, reference +// Class is defined in xla_compilation_device.cc, reference // included here only so the XlaCompilationDevice allocator_ member can be -// defined. +// declared. class XlaCompilationAllocator; // Deliberately don't register the device factory because we *never* @@ -75,6 +52,8 @@ class XlaCompilationDevice : public LocalDevice { Allocator* GetAllocator(AllocatorAttributes attr) override; + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + Status Sync() override; Status MakeTensorFromProto(const TensorProto& tensor_proto, @@ -85,130 +64,75 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr<XlaCompilationAllocator> allocator_; }; -// Class that manages registrations of operators and devices for the XLA JIT. -// Not thread-safe. -class XlaOpRegistry { +struct XlaVariable { + // If this variable is visible externally, what was its argument number? + int arg_num = -1; + + // A descriptive name for the variable, used in error messages. + string name; + + // Current type and value of the variable. Uninitialized variables are + // represented by a default (zero) handle and type DT_INVALID. + // While the type of a variable is notionally fixed during execution, when + // a variable is first initialized we do not yet know its type, so we keep + // track of its type dynamically. + DataType type = DT_INVALID; + xla::ComputationDataHandle value; + + // Value of the variable at computation entry. Used to detect which + // variables have new values that need to be written back. + xla::ComputationDataHandle initial_value; + + // We treat TensorArrays as a Variable with some extra metadata. + + // 'tensor_array_size' stores the expected size of the TensorArray. We need + // to store this because TensorArrays are sometimes initialized lazily, as we + // do not know the element shape at construction time. + int64 tensor_array_size = -1; + + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes + // to an XlaVariable containing the gradient TensorArrays. We store a pointer + // here since there should only be one gradient TensorArray per 'source' + // string, irrespective of the number of calls to TensorArrayGrad. + std::unordered_map<string, XlaVariable*> tensor_array_gradient; +}; + +// An XlaExpression wraps an XLA computation. Each Tensor on an +// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor +// matches the shape of the subcomputation in the ComputationDataHandle. Each +// expression is either a constant, or a function of previously-compiled +// expressions. 
+class XlaExpression { public: - typedef OpKernel* (*Factory)(OpKernelConstruction*); + XlaExpression(); - // Registers 'jit_device_name' as the JIT device corresponding to - // 'device_name'. If 'requires_jit' is true, then operators placed on this - // device must be JIT-compiled. Dies if a conflicting registration already - // exists. - static void RegisterJitDevice(const string& device_name, - const string& jit_device_name, - bool requires_jit); + // handle() stores the XLA handle of the computation that the + // expression represents. + void set_handle(const xla::ComputationDataHandle& h); + const xla::ComputationDataHandle& handle() const { return handle_; } - // Returns the JIT device name associated with 'device_name', setting - // 'jit_device_name' and 'requires_jit', if they are not null. Returns false - // and leaves 'jit_device_name' and 'requires_jit' unchanged if no matching - // JIT device is registered. - static bool GetJitDevice(const string& device_name, - const string** jit_device_name, bool* requires_jit); + void set_constant_value(Tensor value); + bool has_constant_value() const { return has_constant_value_; } + const Tensor& constant_value() const { return constant_value_; } - // Registers all JIT kernels on JIT devices, if not already registered. - // Does nothing otherwise. - static void RegisterJitKernels(); - - // Returns KernelDefs for JIT ops registered on 'jit_device_type'. - // Does not include kernels registered using REGISTER_XLA_JIT_ONLY_KERNEL. - static std::vector<const KernelDef*> DeviceKernels( - const string& jit_device_type); + void set_variable(XlaVariable* variable) { variable_ = variable; } + XlaVariable* variable() const { return variable_; } private: - friend class XlaKernelRegistrar; - friend class XlaOpRegistrar; + // The XLA handle of the expression's computation. + xla::ComputationDataHandle handle_; - static XlaOpRegistry& Instance(); + // If this expression is a constant with a known value, 'constant_value' is a + // host-memory Tensor containing the value. Used to avoid invoking XLA for + // expressions that are trivially constant. + bool has_constant_value_ = false; + Tensor constant_value_; - XlaOpRegistry(); - ~XlaOpRegistry(); + XlaVariable* variable_ = nullptr; // Not owned. - mutex mutex_; - - // Map from Tensorflow device names to the corresponding JIT device names. - std::unordered_map<string, std::pair<string, bool>> jit_devices_ - GUARDED_BY(mutex_); - - // Map from operator name to OpKernel factory, populated by REGISTER_XLA_OP. - std::unordered_map<string, Factory> ops_ GUARDED_BY(mutex_); - - // Have we already registered the JIT kernels on the JIT devices? - bool jit_kernels_registered_ = false; - - struct XlaKernel { - // Should this kernel be registered only on JIT devices, without a dummy - // kernel registered on the corresponding XLA device? - bool jit_only; - - // KernelDef as built by REGISTER_XLA_KERNEL. - std::unique_ptr<const KernelDef> kernel_def; - }; - - // Map from JIT device name to a vector of XLA kernel descriptors. - std::unordered_map<string, std::vector<XlaKernel>> kernels_ - GUARDED_BY(mutex_); - - // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel - // registrations created by RegisterJitKernels() and RegisterDeviceKernels(). - std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>> - kernel_registrars_ GUARDED_BY(mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); }; -// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example: -// REGISTER_XLA_OP("Add", AddOp); -// where 'AddOp' is the name of a JIT OpKernel class that implements "Add". 
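The XlaExpression class added above is a tagged value: either a compile-time constant held as a host tensor, or a handle naming a node in the symbolic XLA computation built so far. A minimal standalone analog, not TensorFlow code, with a double standing in for the host Tensor:

```c++
// Standalone analog of the XlaExpression added above: a value that is either
// a known compile-time constant or a handle naming a node of the symbolic
// computation built so far.
#include <cassert>

struct Handle { long id = 0; };  // Stand-in for xla::ComputationDataHandle.

class Expression {
 public:
  void set_handle(Handle h) { handle_ = h; }
  Handle handle() const { return handle_; }

  void set_constant_value(double v) {  // The real code stores a host Tensor.
    has_constant_value_ = true;
    constant_value_ = v;
  }
  bool has_constant_value() const { return has_constant_value_; }
  double constant_value() const {
    assert(has_constant_value_);
    return constant_value_;
  }

 private:
  Handle handle_;
  bool has_constant_value_ = false;
  double constant_value_ = 0.0;
};

int main() {
  Expression e;
  e.set_constant_value(3.0);  // Constant-folded results can skip XLA entirely.
  assert(e.has_constant_value() && e.constant_value() == 3.0);
}
```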
-// -// We don't use a variadic macro here because we don't expect JIT operators to -// be templated. - -#define REGISTER_XLA_OP(NAME, OP) \ - REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) - -// REGISTER_XLA_KERNEL() associates an XLA OpKernel with a particular device and -// set of type constraints, e.g., -// REGISTER_XLA_KERNEL(DEVICE_XLA_CPU_JIT, -// Name("Relu").TypeConstraint("T", DT_FLOAT)); -// -// REGISTER_XLA_JIT_ONLY_KERNEL is similar to REGISTER_XLA_KERNEL(), but causes -// XlaOpRegistry::RegisterDeviceKernels() to ignore the kernel. - -#define REGISTER_XLA_KERNEL(DEVICE, BUILDER) \ - REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, false) - -#define REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE, BUILDER) \ - REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, true) - -// Implementation details. - -class XlaOpRegistrar { - public: - XlaOpRegistrar(StringPiece name, XlaOpRegistry::Factory factory); -}; - -#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, NAME, OP) \ - REGISTER_XLA_OP_UNIQ(COUNTER, NAME, OP) - -#define REGISTER_XLA_OP_UNIQ(CTR, NAME, OP) \ - static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ - NAME, [](::tensorflow::OpKernelConstruction* context) \ - -> ::tensorflow::OpKernel* { return new OP(context); }); - -// Implementation details. -class XlaKernelRegistrar { - public: - XlaKernelRegistrar(bool jit_only, const KernelDef* def); -}; - -#define REGISTER_XLA_KERNEL_UNIQ_HELPER(COUNTER, DEVICE, BUILDER, JIT_ONLY) \ - REGISTER_XLA_KERNEL_UNIQ(COUNTER, DEVICE, BUILDER, JIT_ONLY) - -#define REGISTER_XLA_KERNEL_UNIQ(CTR, DEVICE, BUILDER, JIT_ONLY) \ - static ::tensorflow::XlaKernelRegistrar \ - xla_kernel_registrar__body__##CTR##__object( \ - JIT_ONLY, \ - ::tensorflow::register_kernel::BUILDER.Device(DEVICE).Build()); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 17adb9b1fdd..580ce3d802e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" @@ -37,36 +38,18 @@ namespace tensorflow { namespace { -Status CheckSignature(const DataTypeVector& tf_types, - const xla::Shape& xla_shape) { - if (xla::ShapeUtil::IsTuple(xla_shape)) { - if (xla::ShapeUtil::TupleElementCount(xla_shape) != tf_types.size()) { - return errors::Internal("XLA shape has ", - xla::ShapeUtil::TupleElementCount(xla_shape), - " elements while function has ", tf_types.size()); - } - for (int i = 0; i < tf_types.size(); ++i) { - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[i], &type)); - if (type != - xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type()) { - return errors::Internal( - "element ", i, " has XLA type ", - xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type(), - " and TensorFlow type ", DataTypeString(tf_types[i])); - } - } - } else { - if (tf_types.size() != 1) { - return errors::Internal("Expected singleton type, got ", tf_types.size(), - " types"); - } - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[0], &type)); - if (type != xla_shape.element_type()) { - return errors::Internal("singleton element has XLA type ", - xla_shape.element_type(), " and TensorFlow type ", - DataTypeString(tf_types[0])); +// Checks that arguments `args` match types `types`. +Status CheckSignature(const DataTypeVector& types, + const std::vector& args) { + if (args.size() != types.size()) { + return errors::Internal("Compilation arguments have ", args.size(), + " elements while function has ", types.size()); + } + for (int i = 0; i < types.size(); ++i) { + if (types[i] != args[i].type && types[i] != DT_RESOURCE) { + return errors::Internal( + "Argument ", i, " has declared type ", DataTypeString(args[i].type), + " but function parameter has type ", DataTypeString(types[i])); } } return Status::OK(); @@ -74,15 +57,39 @@ Status CheckSignature(const DataTypeVector& tf_types, } // namespace -XlaCompiler::XlaCompiler(const XlaCompiler::Options& options) - : client_(options.client), - allow_cpu_custom_calls_(options.allow_cpu_custom_calls), - local_executable_has_hybrid_result_( - options.local_executable_has_hybrid_result), - resolve_compile_time_constants_(options.resolve_compile_time_constants), +bool XlaCompiler::Argument::operator==( + const XlaCompiler::Argument& other) const { + if (std::tie(kind, type, shape, name, tensor_array_size) != + std::tie(other.kind, other.type, other.shape, other.name, + other.tensor_array_size)) { + return false; + } + if (constant_value.shape() != other.constant_value.shape()) { + return false; + } + return constant_value.tensor_data() == other.constant_value.tensor_data(); +} + +XlaCompiler::XlaCompiler(XlaCompiler::Options options) + : options_(options), + initialization_status_(Status::OK()), next_step_id_(1), - device_(new XlaCompilationDevice(SessionOptions(), options.device_type)), - device_mgr_({device_}) {} + device_( + new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_mgr_({device_}) { + // We no longer need the device_type. 
+ options_.device_type = nullptr; + + if (options_.populate_resource_manager) { + initialization_status_ = + (*options_.populate_resource_manager)(device_->resource_manager()); + } + + flib_runtime_.reset(NewFunctionLibraryRuntime( + &device_mgr_, Env::Default(), device_, options.graph_def_version, + options.flib_def, OptimizerOptions(), + nullptr /* custom_kernel_creator */)); +} XlaCompiler::~XlaCompiler() = default; @@ -91,102 +98,63 @@ int64 XlaCompiler::NextStepId() { return next_step_id_++; } +uint64 XlaCompiler::SignatureHash::operator()( + const std::pair>& signature) const { + return std::hash()(signature.first); +} + Status XlaCompiler::CompileFunction( - FunctionLibraryRuntime* flr, const NameAttrList& function, + const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, XlaCompiler::CompilationResult* result) { - const string function_id = Canonicalize(function.name(), function.attr()); + const string function_id = + Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR( - flr->Instantiate(function.name(), function.attr(), &handle)); - - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); - - return CompileFunctionBody(flr, *fbody, function_id, args, - /*use_tuple_arg=*/false, result); -} - -Status XlaCompiler::CompileSubComputation(FunctionLibraryRuntime* flr, - const NameAttrList& function, - const xla::Shape& input_shape, - const xla::Shape& output_shape, - xla::Computation* computation) { - const string function_id = Canonicalize(function.name(), function.attr()); - VLOG(1) << "XlaCompiler::CompileSubComputation " << function_id; - - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR( - flr->Instantiate(function.name(), function.attr(), &handle)); - - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); - - TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, input_shape)); - TF_RETURN_IF_ERROR(CheckSignature(fbody->ret_types, output_shape)); - - const bool use_tuple_arg = xla::ShapeUtil::IsTuple(input_shape); - - std::vector args(fbody->arg_types.size()); - if (use_tuple_arg) { - for (int i = 0; i < args.size(); ++i) { - xla::Shape xla_shape = - xla::ShapeUtil::GetTupleElementShape(input_shape, i); - args[i].type = fbody->arg_types[i]; - args[i].shape = XLAShapeToTensorShape(xla_shape); - args[i].parameter = i; - } - } else { - args[0].type = fbody->arg_types[0]; - args[0].shape = XLAShapeToTensorShape(input_shape); - args[0].parameter = 0; + auto it = cache_.find({function_id, args}); + if (it != cache_.end()) { + *result = it->second; + return Status::OK(); } - CompilationResult result; - TF_RETURN_IF_ERROR(CompileFunctionBody(flr, *fbody, function_id, args, - use_tuple_arg, &result)); + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flib_runtime_->Instantiate( + function.name(), AttrSlice(&function.attr()), &handle)); - if (!xla::ShapeUtil::Compatible(result.xla_output_shape, output_shape)) { - return errors::Internal("output shape mismatch from compilation"); - } - *computation = std::move(result.computation); + const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle); + CHECK(fbody); - return Status::OK(); -} + TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); -Status XlaCompiler::CompileFunctionBody( - FunctionLibraryRuntime* flr, const FunctionBody& fbody, - const string& function_id, const std::vector& args, - bool 
use_tuple_arg, XlaCompiler::CompilationResult* result) { - VLOG(1) << "XlaCompiler::CompileFunctionBody " << function_id; - - std::unique_ptr graph(new Graph(flr->GetFunctionLibraryDefinition())); - CopyGraph(*fbody.graph, graph.get()); + std::unique_ptr graph(new Graph(options_.flib_def)); + CopyGraph(*fbody->graph, graph.get()); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( - strings::StrCat("xla_jit_raw_input_", function_id), *graph); + strings::StrCat("xla_compile_function_input_", function_id), *graph); } // Optimize the graph before running the compiler. - // TODO(pbar): The constant folder currently does not simplify int32 - // operations for devices other than CPU. OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); - OptimizeGraph(flr, &graph); + optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), + /*device=*/nullptr, &graph); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( - strings::StrCat("xla_jit_final_graph_", function_id), *graph); + strings::StrCat("xla_compile_function_optimized_", function_id), + *graph); } VLOG(1) << "===================================================="; - TF_RETURN_IF_ERROR(CompileGraph(function_id, std::move(graph), flr, args, - use_tuple_arg, result)); + TF_RETURN_IF_ERROR( + CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; + cache_[{function_id, args}] = *result; return Status::OK(); } @@ -199,7 +167,7 @@ Status XlaCompiler::BuildExecutable( std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { - argument_layouts[i] = &result.xla_input_shapes[i].second; + argument_layouts[i] = &result.xla_input_shapes[i]; } if (result.requires_runtime_context) { // The final arg is the XlaLocalRuntimeContext*. @@ -210,9 +178,10 @@ Status XlaCompiler::BuildExecutable( build_options.set_device_ordinal(local_client->default_device_ordinal()); build_options.set_platform(local_client->platform()); build_options.set_result_layout(result.xla_output_shape); - build_options.set_has_hybrid_result(local_executable_has_hybrid_result_); + build_options.set_has_hybrid_result( + options_.local_executable_has_hybrid_result); - auto compile_result = local_client->Compile(result.computation, + auto compile_result = local_client->Compile(*result.computation, argument_layouts, build_options); if (!compile_result.ok()) { return compile_result.status(); @@ -256,24 +225,12 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, std::unique_ptr exec(exec_ptr); // At this point ownership of the graph has been transferred to exec. - auto runner = [](Executor::Args::Closure c) { - // TODO(misard) Temporarily just schedule c eagerly while we - // decide what to do about the fact that the ComputationBuilder is - // thread-compatible, but we don't really want Op writers to have - // to remember to acquire a lock around every call to - // ComputationBuilder. One possibility is to add the (generally - // useful) ability to run a single-threaded Executor based on an - // option in LocalExecutorParams. Another is to automagically - // acquire a lock around ComputationBuilder calls using some - // wrapper or RAII funny business. - c(); - }; - // Run the graph symbolically, turning the graph into an XLA computation. 
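The rewritten CompileFunction a few hunks above memoizes compilation results in cache_, keyed by the pair (canonicalized function name, argument list). SignatureHash hashes only the name, which is cheap and still valid because the hash container resolves collisions with the full key comparison, including Argument::operator==. A standalone analog of that cache design, not TensorFlow code:

```c++
// Standalone analog of the compilation cache used by CompileFunction above:
// the hash covers only the function name, and the full (name, argument-list)
// equality check resolves collisions, mirroring SignatureHash and
// Argument::operator==.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Argument {
  int type = 0;             // Stand-in for DataType.
  std::vector<long> shape;  // Stand-in for TensorShape.
  bool operator==(const Argument& o) const {
    return type == o.type && shape == o.shape;
  }
};

using Signature = std::pair<std::string, std::vector<Argument>>;

struct SignatureHash {
  std::size_t operator()(const Signature& s) const {
    return std::hash<std::string>()(s.first);  // Name only; cheap but valid.
  }
};

int main() {
  std::unordered_map<Signature, std::string, SignatureHash> cache;
  Signature key{"my_fn", {{1, {2, 3}}}};
  cache[key] = "compiled-result";
  // Same name but different shapes: a hash collision, disambiguated by ==.
  Signature other{"my_fn", {{1, {4, 4}}}};
  std::cout << cache.count(key) << " " << cache.count(other) << "\n";  // 1 0
}
```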
Executor::Args exec_args; exec_args.step_id = step_id; exec_args.step_container = step_container.get(); - exec_args.runner = runner; + // Run all compilation kernels on the main thread. + exec_args.runner = [](Executor::Args::Closure c) { c(); }; TF_RETURN_WITH_CONTEXT_IF_ERROR( exec->Run(exec_args), "Conversion from TensorFlow graph to XLA computation failed."); @@ -283,84 +240,245 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return cleanup_status; } +// Builds XLA computations for each of the arguments to the computation. +// `args` are the arguments to the computation. +Status BuildArguments(const std::vector& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, + std::vector* context_args, + std::vector* input_mapping, + std::vector* input_shapes) { + context_args->resize(args.size()); + + // Argument numbers of arguments and variables that are to be passed to the + // XLA computation as runtime parameters. + std::vector parameters, variables; + parameters.reserve(args.size()); + variables.reserve(args.size()); + + for (std::vector::size_type i = 0; i < args.size(); + ++i) { + XlaContext::Argument& context_arg = (*context_args)[i]; + context_arg.name = args[i].name; + context_arg.value.constant_value = args[i].constant_value; + context_arg.value.type = args[i].type; + + switch (args[i].kind) { + case XlaCompiler::Argument::kVariable: + variables.push_back(i); + context_arg.is_variable = true; + context_arg.value.is_constant = false; + context_arg.tensor_array_size = args[i].tensor_array_size; + break; + case XlaCompiler::Argument::kParameter: + parameters.push_back(i); + context_arg.value.is_constant = false; + break; + case XlaCompiler::Argument::kUninitializedVariable: + context_arg.is_variable = true; + context_arg.value.is_constant = true; + context_arg.tensor_array_size = args[i].tensor_array_size; + break; + case XlaCompiler::Argument::kConstant: + context_arg.value.is_constant = true; + break; + case XlaCompiler::Argument::kInvalid: + return errors::Internal("Unreachable case in BuildArguments()"); + } + } + + // Append parameters containing variable values after the other runtime + // parameters. + parameters.insert(parameters.end(), variables.begin(), variables.end()); + if (parameters.empty()) { + return Status::OK(); + } + + input_shapes->resize(parameters.size()); + input_mapping->resize(parameters.size()); + for (std::vector::size_type i = 0; i < input_shapes->size(); ++i) { + const XlaCompiler::Argument& arg = args[parameters[i]]; + // Computes the shapes of non-constant arguments. + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type)); + xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(), + &(*input_shapes)[i]); + (*input_mapping)[i] = parameters[i]; + } + + if (use_tuple_arg) { + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); + xla::ComputationDataHandle tuple = + builder->Parameter(0, tuple_shape, "arg_tuple"); + for (std::vector::size_type i = 0; i < input_shapes->size(); ++i) { + (*context_args)[parameters[i]].value.handle = + builder->GetTupleElement(tuple, i); + } + } else { + for (std::vector::size_type i = 0; i < input_shapes->size(); ++i) { + (*context_args)[parameters[i]].value.handle = + builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); + } + } + return Status::OK(); +} + +// Builds the XLA computation. +// +// `retvals` is the list of retvals produced by _Retval operators, in index +// order. 
`variables` is the list of XlaVariable +// states generated by the symbolic evaluation. +// If `has_side_effects` is true, the computation has side effects and should be +// built even if it has no outputs. +// If `return_updated_values_for_all_variables` is true, all variables will be +// included in `variable_updates`, regardless of whether their value changed. +// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// Sets `*variable_updates` to a description of variables whose values are +// written by the computation; the variable writes are the last +// `variable_updates.size()` return values from the computation. Each entry in +// `variable_updates` is a (input_index, type) pair, where `input_index` is the +// index of a resource variable argument to the computation, and `type` is the +// type of the final output. +Status BuildComputation( + const std::vector<XlaContext::HandleOrConstant>& retvals, + const std::vector<std::unique_ptr<XlaVariable>>& variables, + bool has_side_effects, bool return_updated_values_for_all_variables, + xla::ComputationBuilder* builder, xla::Computation* computation, + int* num_nonconst_outputs, + std::vector<XlaCompiler::VariableUpdate>* variable_updates) { + std::vector<xla::ComputationDataHandle> elems; + elems.reserve(retvals.size()); + for (const XlaContext::HandleOrConstant& retval : retvals) { + if (!retval.is_constant) { + elems.push_back(retval.handle); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for variables whose values have changed. + std::vector<XlaVariable*> arg_vars; + arg_vars.reserve(variables.size()); + for (const auto& var : variables) { + if (var->arg_num >= 0) { + arg_vars.push_back(var.get()); + } + } + std::sort(arg_vars.begin(), arg_vars.end(), + [](const XlaVariable* a, const XlaVariable* b) { + return a->arg_num < b->arg_num; + }); + + for (const XlaVariable* var : arg_vars) { + bool modified = var->value.handle() != var->initial_value.handle(); + if (return_updated_values_for_all_variables || modified) { + variable_updates->emplace_back(); + XlaCompiler::VariableUpdate& update = variable_updates->back(); + update.input_index = var->arg_num; + update.type = var->type; + update.modified = modified; + elems.push_back(var->value); + } + } + + if (!elems.empty() || has_side_effects) { + // Builds an empty tuple return value for computations that have side + // effects but have no return values. + xla::ComputationDataHandle handle = builder->Tuple(elems); + + // TODO(b/31775371): to work around a bug, we must build a no-op computation + // that is guaranteed to be constructed after all of the formal parameters + // to the computation. Once the bug is fixed, we could avoid tupling here. + if (elems.size() == 1) { + handle = builder->GetTupleElement(handle, 0); + } + + // Builds the XLA computation. + xla::StatusOr<xla::Computation> computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + } + return Status::OK(); +} + } // namespace -Status XlaCompiler::CompileGraph(string const& name, +Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, + string const& name, std::unique_ptr<Graph> graph, - FunctionLibraryRuntime* flib, const std::vector<Argument>& args, - bool use_tuple_arg, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; - // Converts the input shapes into xla::Shape instances. 
- result->xla_input_shapes.reserve(args.size()); - for (int i = 0; i < args.size(); ++i) { - if (args[i].parameter < 0) { - continue; - } - result->xla_input_shapes.push_back(std::make_pair(i, xla::Shape())); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - args[i].type, args[i].shape, &result->xla_input_shapes.back().second)); - } + // Report the error here if initialization failed. + TF_RETURN_IF_ERROR(initialization_status_); - XlaContext* xla_context = - new XlaContext(this, client(), name, allow_cpu_custom_calls_, - resolve_compile_time_constants_); - core::ScopedUnref xla_context_unref(xla_context); + xla::ComputationBuilder builder(client(), name); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + options_.resolve_compile_time_constants); + core::ScopedUnref context_unref(context); - TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg)); + result->tuple_arg = options.use_tuple_arg; - TF_RETURN_IF_ERROR( - ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId())); + std::vector context_args; + TF_RETURN_IF_ERROR(BuildArguments(args, options.use_tuple_arg, &builder, + &context_args, &result->input_mapping, + &result->xla_input_shapes)); + context->set_args(std::move(context_args)); + + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, + flib_runtime_.get(), NextStepId())); - std::vector compile_time_constants; int num_nonconst_outputs; - TF_RETURN_IF_ERROR(xla_context->CollectResults( - &result->computation, &result->requires_runtime_context, - &compile_time_constants, &num_nonconst_outputs)); + result->computation = std::make_shared(); + TF_RETURN_IF_ERROR(BuildComputation( + context->retvals(), context->variables(), context->has_side_effects(), + options.return_updated_values_for_all_variables, &builder, + result->computation.get(), &num_nonconst_outputs, + &result->variable_updates)); - VLOG(2) << "Outputs: constant: " << compile_time_constants.size() + result->requires_runtime_context = context->has_context_parameter(); + + // Tuple arguments and runtime context parameters are incompatible. + CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); + + VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs); - for (const auto& c : compile_time_constants) { - if (!c.status.ok()) { - Status constant_status = c.status; - errors::AppendToMessage(&constant_status, - "Failed evaluating constant XLA return " - "value ", - c.index); - return constant_status; + result->outputs.resize(context->retvals().size()); + for (std::vector::size_type i = 0; + i < context->retvals().size(); ++i) { + const XlaContext::HandleOrConstant& retval = context->retvals()[i]; + if (retval.is_constant) { + OutputDescription& output = result->outputs[i]; + output.shape = retval.constant_value.shape(); + output.is_constant = true; + output.constant_value = retval.constant_value; } - if (c.index >= result->outputs.size()) { - return errors::InvalidArgument("Invalid argument index ", c.index); - } - OutputDescription& output = result->outputs[c.index]; - output.shape = c.value.shape(); - output.is_constant = true; - output.constant_value = c.value; } - if (result->computation.IsNull()) { + if (result->computation->IsNull()) { return Status::OK(); } // Compute the output shapes, if there is a computation with non-constant // outputs. 
- auto computation_shape = client()->GetComputationShape(result->computation); + auto computation_shape = client()->GetComputationShape(*result->computation); if (!computation_shape.ok()) { return computation_shape.status(); } result->xla_output_shape.Swap( computation_shape.ValueOrDie()->mutable_result()); + VLOG(2) << "XLA output shape: " + << xla::ShapeUtil::HumanString(result->xla_output_shape); - auto num_non_constant_outputs = + auto num_computation_outputs = (xla::ShapeUtil::IsTuple(result->xla_output_shape)) ? xla::ShapeUtil::TupleElementCount(result->xla_output_shape) : 1; // Tensorflow expects a major-to-minor order of results. - if (1 == num_non_constant_outputs) { + if (1 == num_computation_outputs) { xla::Shape& s = result->xla_output_shape; auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major(); minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0); @@ -375,20 +493,37 @@ Status XlaCompiler::CompileGraph(string const& name, // Converts the output shapes to TensorShapes. int computation_output = 0; - for (int i = 0; i < result->outputs.size(); ++i) { - if (!result->outputs[i].is_constant) { - CHECK_LT(computation_output, num_non_constant_outputs); - if (num_non_constant_outputs > 1) { - result->outputs[i].shape = + for (std::vector::size_type i = 0; + i < context->retvals().size(); ++i) { + const XlaContext::HandleOrConstant& retval = context->retvals()[i]; + if (!retval.is_constant) { + CHECK_LT(computation_output, num_computation_outputs); + OutputDescription& output = result->outputs[i]; + output.is_constant = false; + if (num_computation_outputs > 1) { + output.shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( result->xla_output_shape, computation_output)); } else { - result->outputs[i].shape = - XLAShapeToTensorShape(result->xla_output_shape); + output.shape = XLAShapeToTensorShape(result->xla_output_shape); } ++computation_output; } } + + for (std::vector::size_type i = 0; + i < result->variable_updates.size(); ++i) { + if (num_computation_outputs > 1) { + result->variable_updates[i].shape = + XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( + result->xla_output_shape, computation_output)); + } else { + CHECK_EQ(0, computation_output); + result->variable_updates[i].shape = + XLAShapeToTensorShape(result->xla_output_shape); + } + ++computation_output; + } return Status::OK(); } @@ -397,7 +532,7 @@ Status XlaCompiler::GetChannelHandle(const string& key, mutex_lock lock(mu_); auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { - TF_ASSIGN_OR_RETURN(result.first->second, client_->CreateChannelHandle()); + TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle()); } *channel = result.first->second; VLOG(1) << "Channel: " << key << " " << channel->DebugString(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index df40af3bbd4..13143055325 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -34,36 +35,95 @@ namespace tensorflow { // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// It is typically invoked from an `_XlaLaunch` operator once the shapes -// of all input parameters to the computation are known. This is +// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. +// +// XlaCompiler compiles Tensorflow graphs that received inputs via _Arg nodes, +// and return outputs via _Retval nodes. +// +// The XlaCompiler requires one Argument struct for each _Arg index, that +// describes each argument. Arguments can be compile-time constants +// (kind kConstant), run-time parameters (kind kParameter), or resource +// variables (kinds kVariable and kUninitializedVariable). +// +// Only kParameter and kVariable arguments become runtime parameters to the +// generated XLA computation. The XLA computation will have run-time parameters +// in the following order: +// +---------------------+-----------------------------------------+ +// | kParameter values | Initial values of kVariable arguments | +// +---------------------+-----------------------------------------+ +// Within each block, the arguments are arranged by the _Arg index from which +// they were derived. +// If `Options::requires_runtime_context` is true, then an additional runtime +// context argument is passed as a final argument. +// +// The run-time outputs of the XLA computation are arranged in the following +// order: +// +------------------+-----------------------------------------+ +// | _Retval values | Updated values of kVariable arguments | +// +------------------+-----------------------------------------+ +// _Retval values are ordered by _Retval index, whereas kVariable values are +// ordered by the original _Arg position of the variable. +// +// In both inputs and outputs, kVariable values are placed the end. When +// emitting While loop bodies, we must ensure that the loop body has +// identical input and output signatures. By moving variable values +// to the end of the argument list and using the +// `return_updated_values_for_all_variables` option, we can ensure that the +// input and output values of variables appear at the same positions. + class XlaCompiler { public: // Describes how to derive the value of each _Arg node in the graph/function - // being compiled. Each argument must be either a parameter of the generated - // XLA computation (parameter >= 0), or a compile time constant - // (parameter < 0). + // being compiled. There must be one Argument for each _Arg index. struct Argument { - // The type of the argument. + enum Kind { + // Default value; not a valid kind. + kInvalid, + + // Argument is a compile-time constant. No associated runtime parameter. + kConstant, + + // Argument is a variable that has not been initialized yet. No associated + // runtime parameter. + kUninitializedVariable, + + // Argument is a variable that already has a value set. 
+
 class XlaCompiler {
  public:
   // Describes how to derive the value of each _Arg node in the graph/function
-  // being compiled. Each argument must be either a parameter of the generated
-  // XLA computation (parameter >= 0), or a compile time constant
-  // (parameter < 0).
+  // being compiled. There must be one Argument for each _Arg index.
   struct Argument {
-    // The type of the argument.
+    enum Kind {
+      // Default value; not a valid kind.
+      kInvalid,
+
+      // Argument is a compile-time constant. No associated runtime parameter.
+      kConstant,
+
+      // Argument is a variable that has not been initialized yet. No
+      // associated runtime parameter.
+      kUninitializedVariable,
+
+      // Argument is a variable that already has a value set. Expects a
+      // runtime parameter containing the current value.
+      kVariable,
+
+      // Argument is a run-time parameter.
+      kParameter,
+    };
+
+    Kind kind = kInvalid;
+
+    // The type of the argument. If the argument is a resource variable, this
+    // is the type of the variable's value, not DT_RESOURCE.
     DataType type;
 
-    // The shape of the argument.
+    // The shape of the argument. If the argument is a resource variable, this
+    // is the shape of the variable's value.
     TensorShape shape;
 
-    // The parameter number of this argument to the XLA computation. < 0
-    // means this is a compile-time constant argument.
-    int parameter;
-
     // The value of the argument, if it is a compile-time constant. Must be a
     // host-memory tensor.
     Tensor constant_value;
 
     // The name of this argument, used for debugging.
     string name;
+
+    // For a kVariable or kUninitializedVariable corresponding to a
+    // TensorArray, what is the tensor array's declared size?
+    int64 tensor_array_size = -1;
+
+    bool operator==(const Argument& other) const;
   };
 
   struct OutputDescription {
-    // Shape of the output.
+    // Type and shape of the output.
+    DataType type;
     TensorShape shape;
 
     // Constant output value, if known to be constant at JIT compilation time.
@@ -72,37 +132,69 @@ class XlaCompiler {
     Tensor constant_value;
   };
 
+  // Describes a variable write side effect of the computation.
+  struct VariableUpdate {
+    // Index of the input that contains the variable resource to write to.
+    int input_index;
+
+    // Type and shape of the tensor to be written back.
+    DataType type;
+    TensorShape shape;
+
+    // Was the value of the variable modified by the computation?
+    // (Always true, unless `return_updated_values_for_all_variables` is true.)
+    bool modified;
+  };
+
   struct CompilationResult {
-    // Vector of (Tensorflow input number, XLA shape) pairs that describe
-    // the arguments of the compiled XLA computation. (Because of constant
-    // inputs, the arguments to the XLA computation are a subset of the
-    // inputs passed to the JIT.)
-    std::vector<std::pair<int, xla::Shape>> xla_input_shapes;
+    // Vector that maps from the parameters of the XLA computation to their
+    // original argument positions. To handle compile-time constant inputs and
+    // variables, the parameters to the XLA computation may be a subset of the
+    // original arguments, and are not necessarily in the same order.
+    std::vector<int> input_mapping;
 
     // Does the computation require the local runtime context to be passed as
     // the last argument?
     bool requires_runtime_context = false;
 
-    // Output shape in XLA format. This is a tuple if and only if
-    // there are multiple non-constant outputs.
+    // Input shapes of the computation.
+    std::vector<xla::Shape> xla_input_shapes;
+
+    // Should the arguments be packed into a single tuple?
+    bool tuple_arg;
+
+    // Output shape in XLA format. The output shape is a tuple if and only if
+    // the number of non-constant outputs is not equal to 1.
     xla::Shape xla_output_shape;
 
     // TensorFlow shapes of outputs, together with the values of any
     // constant arguments. Vector indexed by Tensorflow _Retval number,
-    // containing both constant and non-constant arguments.
+    // containing both constant and non-constant results.
     std::vector<OutputDescription> outputs;
 
+    // Variables whose values were updated by the computation, ordered
+    // by return value position. Variable updates follow the non-constant
+    // results in the outputs of the XLA computation.
+    std::vector<VariableUpdate> variable_updates;
+
     // The XLA computation built from the tensorflow subgraph. May be null
     // if the output consists solely of compile-time constants.
-    xla::Computation computation;
+    std::shared_ptr<xla::Computation> computation;
   };
 
   struct Options {
-    // Name of the compilation device to use.
-    DeviceType device_type = DeviceType("");
+    // Name of the compilation device to use. Needs to be live only during
+    // XlaCompiler's constructor.
+    const DeviceType* device_type = nullptr;
 
     xla::Client* client = nullptr;
 
+    // Function library in which to find function definitions. Must be
+    // non-null.
+    const FunctionLibraryDefinition* flib_def = nullptr;
+
+    // The graph def version to be compiled.
+    int graph_def_version = TF_GRAPH_DEF_VERSION;
+
     // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
     // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed
     // to the computation.
@@ -119,22 +211,43 @@ class XlaCompiler {
     // as Tensors at compile-time, rather than as run-time outputs of the
     // computation.
     bool resolve_compile_time_constants = true;
+
+    // If not nullptr, populate_resource_manager is called with the
+    // compilation device's resource manager when the compilation
+    // device is created, and can be used to create metadata objects
+    // that can be accessed by XLA op kernels.
+    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
   };
 
-  explicit XlaCompiler(const Options& options);
+  explicit XlaCompiler(Options options);
 
   ~XlaCompiler();
 
+  // Options pertaining to an individual call to CompileGraph() or
+  // CompileFunction().
+  struct CompileOptions {
+    // If `use_tuple_arg` is true, a single tuple parameter will be used for
+    // all arguments; if false, each argument gets its own parameter.
+    bool use_tuple_arg = false;
+
+    // If 'return_updated_values_for_all_variables' is true, then updated
+    // values of all resource variable arguments will be included in the
+    // 'variable_updates' of the computation, even if the variable was not
+    // modified by the computation. Used when compiling loop bodies to ensure
+    // the input and output signatures match.
+    bool return_updated_values_for_all_variables = false;
+  };
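As a sketch of the intended use of `CompileOptions` (not part of this change): compiling a While loop body so that its output signature mirrors its input signature. The helper name `CompileWhileBody` is hypothetical, and `compiler`, `body_fn`, and `args` are assumed to be supplied by the caller:

```c++
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_compiler.h"

namespace tensorflow {

// Hypothetical helper: compile a While-loop body with matching signatures.
Status CompileWhileBody(XlaCompiler* compiler, const NameAttrList& body_fn,
                        const std::vector<XlaCompiler::Argument>& args,
                        XlaCompiler::CompilationResult* body) {
  XlaCompiler::CompileOptions options;
  // Pack the loop state into a single tuple parameter.
  options.use_tuple_arg = true;
  // Emit updated values for every variable, modified or not, so inputs and
  // outputs line up position-for-position.
  options.return_updated_values_for_all_variables = true;
  return compiler->CompileFunction(options, body_fn, args, body);
}

}  // namespace tensorflow
```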
+
   // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
   // `args` describes the arguments to the function, each of which must either
-  // be a parameter to the XLA computation or a compile-time constant.
-  // Writes the compiled output to `result`.
+  // be a runtime parameter to the XLA computation, a compile-time constant, or
+  // a resource variable. Writes the compiled output to `result`.
   //
   // The generated XLA computation returns a tuple containing only the
   // non-constant outputs as a function of the input arguments. Constant
   // arguments are returned as host memory tensors in the output list and are
   // not included in the XLA computation's outputs. The XLA computation is
-  // null if there are no data-dependent outputs.
-  Status CompileFunction(FunctionLibraryRuntime* flr,
+  // null if there are no data-dependent outputs and no side effects.
+  Status CompileFunction(const CompileOptions& options,
                          const NameAttrList& fn_name_attrs,
                          const std::vector<Argument>& args,
                          CompilationResult* result);
@@ -142,43 +255,21 @@ class XlaCompiler {
   // Compiles a tensorflow::Graph into an xla::Computation.
   // Similar to CompileFunction, but takes a Graph as input rather than a
   // function.
-  // If `use_tuple_arg` is true, the compilation takes all of its arguments as
-  // a single tuple.
-  Status CompileGraph(string const& name, std::unique_ptr<Graph> graph,
-                      FunctionLibraryRuntime* flr,
-                      const std::vector<Argument>& args, bool use_tuple_arg,
+  Status CompileGraph(const CompileOptions& options, string const& name,
+                      std::unique_ptr<Graph> graph,
+                      const std::vector<Argument>& args,
                       CompilationResult* result);
 
-  // Helper function that compiles a function to an XLA computation suitable
-  // for use as a subroutine in other Computations, e.g., the body of a
-  // While loop.
-  //
-  // The emitted Computation takes a single input parameter with
-  // input_shape. If this is a tuple then the tuple element shapes
-  // must match the types of the function's _Arg nodes. If input_shape
-  // is not a tuple then the function must have a single _Arg node
-  // with the same type as input_shape. The shapes of the _Arg values
-  // will be compiled to match input_shape.
-  //
-  // The emitted Computation also returns a single value. If output_shape is a
-  // tuple the tuple elements' types and shapes must match the compiled
-  // function's _Retval nodes. If output_shape is not a tuple the
-  // function must have a single _Retval node with the correct type
-  // (and shape after compilation).
-  Status CompileSubComputation(FunctionLibraryRuntime* flr,
-                               const NameAttrList& fn_name_attrs,
-                               const xla::Shape& input_shape,
-                               const xla::Shape& output_shape,
-                               xla::Computation* computation);
-
-  // Takes <*result>, which has been compiled from a Tensorflow subgraph to a
+  // Takes `result`, which has been compiled from a Tensorflow subgraph to an
   // XLA computation already, and generates an XLA LocalExecutable
   // `executable`.
   Status BuildExecutable(const CompilationResult& result,
                          std::unique_ptr<xla::LocalExecutable>* executable);
 
-  xla::Client* client() const { return client_; }
+  const Options& options() const { return options_; }
+  xla::Client* client() const { return options_.client; }
   XlaCompilationDevice* device() const { return device_; }
   const DeviceMgr* device_mgr() const { return &device_mgr_; }
+  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); }
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
@@ -187,17 +278,10 @@ class XlaCompiler {
   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
 
  private:
-  // Does the real work of Compile() and CompileToComputation().
-  Status CompileFunctionBody(FunctionLibraryRuntime* flr,
-                             const FunctionBody& function_body,
-                             const string& name,
-                             const std::vector<Argument>& args,
-                             bool use_tuple_arg, CompilationResult* result);
+  Options options_;
 
-  xla::Client* client_;  // Not owned.
-  const bool allow_cpu_custom_calls_;
-  const bool local_executable_has_hybrid_result_;
-  const bool resolve_compile_time_constants_;
+  // Status set to non-OK in the constructor if initialization fails.
+  Status initialization_status_;
 
   // Returns the next step sequence number.
   int64 NextStepId();
@@ -210,6 +294,17 @@ class XlaCompiler {
   XlaCompilationDevice* device_;  // Owned by device_mgr_
   DeviceMgr device_mgr_;
 
+  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
+
+  struct SignatureHash {
+    uint64 operator()(
+        const std::pair<string, std::vector<Argument>>& signature) const;
+  };
+
+  std::unordered_map<std::pair<string, std::vector<Argument>>,
+                     CompilationResult, SignatureHash>
+      cache_;
+
   std::unordered_map<string, xla::ChannelHandle> channels_ GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index ca7c0b17b8c..58d74057d10 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -17,12 +17,14 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -33,12 +35,73 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// Helper class to test the ability to pass resources through to XLA
+// compiled kernels.
+class DummyResourceForTest : public ResourceBase {
+ public:
+  string DebugString() override { return "dummy"; }
+  void Increment() { ++value_; }
+  int Get() { return value_; }
+
+ private:
+  int value_ = 0;
+};
+
+class DummyReadResourceOp : public XlaOpKernel {
+ public:
+  explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
+    OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
+    DummyResourceForTest* dummy;
+    OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
+                            rm->default_container(), "dummy", &dummy));
+    dummy->Increment();
+    dummy->Unref();
+
+    ctx->SetOutput(0, ctx->Input(0));
+  }
+};
+
+class DummyReadResourceCC {
+ public:
+  DummyReadResourceCC(const Scope& scope, const Input& value) {
+    if (!scope.ok()) return;
+    auto _value = ops::AsNodeOut(scope, value);
+    if (!scope.ok()) return;
+    Node* ret;
+    const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
+    auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
+    scope.UpdateBuilder(&builder);
+    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+    if (!scope.ok()) return;
+    this->output_ = Output(ret, 0);
+  }
+  Node* node() const { return output_.node(); }
+
+  Output output_;
+};
+
+REGISTER_OP("DummyReadResource")
+    .Input("input: int32")
+    .Output("output: int32")
+    .Doc(R"doc(
+A dummy Op.
+
+input: dummy input.
+output: dummy output.
+)doc");
+
+REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
+
 class XlaCompilerTest : public ::testing::Test {
  protected:
+  XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
+
   void SetUp() override {
     client_ = xla::ClientLibrary::LocalClientOrDie();
 
-    XlaOpRegistry::RegisterJitKernels();
+    XlaOpRegistry::RegisterCompilationKernels();
 
     FunctionDefLibrary flib;
     flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
@@ -46,19 +109,13 @@ class XlaCompilerTest : public ::testing::Test {
 
   XlaCompiler::Options DefaultOptions() {
     XlaCompiler::Options options;
-    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
+    options.device_type = &cpu_device_type_;
     options.client = client_;
+    options.flib_def = flib_def_.get();
     return options;
   }
 
-  std::unique_ptr<FunctionLibraryRuntime> BuildFunctionLibraryRuntime(
-      const XlaCompiler& compiler) {
-    return std::unique_ptr<FunctionLibraryRuntime>(NewFunctionLibraryRuntime(
-        compiler.device_mgr(), /*env=*/nullptr, compiler.device(),
-        TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(),
-        /*custom_kernel_creator=*/nullptr));
-  }
-
+  DeviceType cpu_device_type_;
   xla::Client* client_;
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
 };
@@ -66,16 +123,15 @@ class XlaCompilerTest : public ::testing::Test {
 // Tests compilation of an empty graph.
 TEST_F(XlaCompilerTest, EmptyReturnValues) {
   XlaCompiler compiler(DefaultOptions());
-  auto flr = BuildFunctionLibraryRuntime(compiler);
 
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph("add", std::move(graph), flr.get(),
-                                     /*args=*/{}, /*use_tuple_arg=*/false,
-                                     &result));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph),
+                                     /*args=*/{}, &result));
 
   // No computation should be generated.
-  EXPECT_EQ(0, result.computation.handle().handle());
+  EXPECT_EQ(0, result.computation->handle().handle());
 }
 
 // Tests compilation and execution of a graph that adds two tensors.
@@ -91,20 +147,19 @@ TEST_F(XlaCompilerTest, Simple) {
 
   // Builds a description of the arguments.
   std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
   args[0].shape = TensorShape({2});
-  args[0].parameter = 0;
+  args[1].kind = XlaCompiler::Argument::kParameter;
   args[1].type = DT_INT32;
   args[1].shape = TensorShape({2});
-  args[1].parameter = 1;
 
   // Compiles the graph.
   XlaCompiler compiler(DefaultOptions());
-  auto flr = BuildFunctionLibraryRuntime(compiler);
 
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph("add", std::move(graph), flr.get(), args,
-                                     /*use_tuple_arg=*/false, &result));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   std::unique_ptr<xla::Literal> param0_literal =
@@ -118,7 +173,7 @@ TEST_F(XlaCompilerTest, Simple) {
 
   std::unique_ptr<xla::GlobalData> actual =
       client_
-          ->Execute(result.computation, {param0_data.get(), param1_data.get()})
+          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
   std::unique_ptr<xla::Literal> actual_literal =
       client_->Transfer(*actual).ConsumeValueOrDie();
@@ -144,23 +199,22 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
 
   // Builds a description of the arguments.
   std::vector<XlaCompiler::Argument> args(1);
+  args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
   args[0].shape = TensorShape({2});
-  args[0].parameter = 0;
 
   {
     // Compiles the graph, with resolve_compile_time_constants enabled.
     XlaCompiler::Options options = DefaultOptions();
     options.resolve_compile_time_constants = true;
    XlaCompiler compiler(options);
-    auto flr = BuildFunctionLibraryRuntime(compiler);
 
     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
     CopyGraph(*graph, graph_copy.get());
 
     XlaCompiler::CompilationResult result;
-    TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy),
-                                       flr.get(), args, /*use_tuple_arg=*/false,
+    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(),
+                                       "constants", std::move(graph_copy), args,
                                        &result));
 
     ASSERT_EQ(2, result.outputs.size());
@@ -176,7 +230,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
         client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
     std::unique_ptr<xla::GlobalData> actual =
-        client_->Execute(result.computation, {param0_data.get()})
+        client_->Execute(*result.computation, {param0_data.get()})
             .ConsumeValueOrDie();
     std::unique_ptr<xla::Literal> actual_literal =
         client_->Transfer(*actual).ConsumeValueOrDie();
@@ -191,14 +245,13 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
     XlaCompiler::Options options = DefaultOptions();
     options.resolve_compile_time_constants = false;
     XlaCompiler compiler(options);
-    auto flr = BuildFunctionLibraryRuntime(compiler);
 
     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
     CopyGraph(*graph, graph_copy.get());
 
     XlaCompiler::CompilationResult result;
-    TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy),
-                                       flr.get(), args, /*use_tuple_arg=*/false,
+    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(),
+                                       "constants", std::move(graph_copy), args,
                                        &result));
 
     ASSERT_EQ(2, result.outputs.size());
@@ -212,7 +265,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
         client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
     std::unique_ptr<xla::GlobalData> actual =
-        client_->Execute(result.computation, {param0_data.get()})
+        client_->Execute(*result.computation, {param0_data.get()})
            .ConsumeValueOrDie();
    std::unique_ptr<xla::Literal> actual_literal =
        client_->Transfer(*actual).ConsumeValueOrDie();
@@ -227,5 +280,44 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   }
 }
 
+// Tests that a resource manager can be passed to XLA op kernels during
+// compilation via Options::populate_resource_manager.
+TEST_F(XlaCompilerTest, ResourceManager) {
+  // Builds a graph that calls the dummy resource Op.
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
+  auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the argument.
+  std::vector<XlaCompiler::Argument> args(1);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2});
+
+  DummyResourceForTest* resource = new DummyResourceForTest();
+
+  // Compiles the graph.
+  auto options = DefaultOptions();
+  std::function<Status(ResourceMgr*)> populate_function =
+      [resource](ResourceMgr* rm) {
+        resource->Ref();
+        return rm->Create(rm->default_container(), "dummy", resource);
+      };
+  options.populate_resource_manager = &populate_function;
+  XlaCompiler compiler(options);
+
+  EXPECT_EQ(0, resource->Get());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
+                                     std::move(graph), args, &result));
+
+  EXPECT_EQ(1, resource->Get());
+
+  resource->Unref();
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index a770271628c..4440b530696 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -22,6 +22,8 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/layout_util.h"
@@ -31,25 +33,9 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
 
 namespace tensorflow {
 
-XlaExpression::XlaExpression() : has_constant_value_(false) {}
-
-void XlaExpression::set_handle(const xla::ComputationDataHandle& h) {
-  handle_ = h;
-}
-const xla::ComputationDataHandle& XlaExpression::handle() const {
-  return handle_;
-}
-
-void XlaExpression::set_constant_value(Tensor value) {
-  has_constant_value_ = true;
-  constant_value_ = std::move(value);
-}
-
 const char XlaContext::kXlaContextResourceName[] = "_xla_context";
 
 // Looks up the context associated with the current step. It is stored
@@ -68,145 +54,37 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
   return *context;
 }
 
-Status XlaContext::BuildArguments(std::vector<XlaCompiler::Argument> args,
-                                  bool use_tuple_arg) {
+/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
+  return Get(ctx->op_kernel_context());
+}
+
+void XlaContext::set_args(std::vector<XlaContext::Argument> args) {
   args_ = std::move(args);
-  use_tuple_arg_ = use_tuple_arg;
-
-  // Compute the number of parameters, verify that they are sequential starting
-  // from 0
-  num_parameters_ = 0;
-  for (const XlaCompiler::Argument& arg : args_) {
-    if (arg.parameter < 0) continue;
-    if (num_parameters_ != arg.parameter) {
-      return errors::InvalidArgument(
-          "Parameter numbers to JIT compilation are not consecutive starting "
-          "from 0");
-    }
-    ++num_parameters_;
-
-    if (arg.shape.num_elements() == 0) {
-      return errors::InvalidArgument(
-          "Non-constant argument must have a non-zero number of elements.");
-    }
-  }
-  if (num_parameters_ == 0) return Status::OK();
-
-  parameters_.resize(num_parameters_);
-
-  std::vector<xla::Shape> parameter_shapes(num_parameters_);
-  for (int i = 0; i < args_.size(); ++i) {
-    const XlaCompiler::Argument& arg = args_[i];
-    if (arg.parameter < 0) continue;
-    // Computes the shapes of non-constant arguments.
-    xla::PrimitiveType type;
-    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
-    xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
-                                  &parameter_shapes[arg.parameter]);
-  }
-
-  if (use_tuple_arg_ && num_parameters_ > 0) {
-    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
-    xla::ComputationDataHandle tuple =
-        builder().Parameter(0, tuple_shape, "arg_tuple");
-    for (int i = 0; i < args_.size(); ++i) {
-      const XlaCompiler::Argument& arg = args_[i];
-      if (arg.parameter < 0) continue;
-      parameters_[arg.parameter] =
-          builder().GetTupleElement(tuple, arg.parameter);
-    }
-  } else {
-    for (int i = 0; i < args_.size(); ++i) {
-      const XlaCompiler::Argument& arg = args_[i];
-      if (arg.parameter < 0) continue;
-      parameters_[arg.parameter] =
-          builder().Parameter(arg.parameter, parameter_shapes[arg.parameter],
-                              strings::StrCat("arg", i));
-    }
-  }
-  return Status::OK();
 }
 
-Status XlaContext::CollectResults(
-    xla::Computation* computation, bool* requires_runtime_context,
-    std::vector<ConstRetVal>* compile_time_constants,
-    int* num_nonconst_outputs) {
-  mutex_lock l(mu_);
-
-  xla::ComputationDataHandle handle;
-  if (retval_.empty() && has_side_effects_) {
-    // Build a empty tuple return value for computations that have side effects
-    // but have no return values.
-    handle = builder().Tuple({});
-  } else if (retval_.size() == 1) {
-    handle = retval_[0].second;
-
-    // TODO(b/31775371): to workaround bug, add a no-op computation that is
-    // guaranteed to be constructed after all of the formal parameters to the
-    // computation.
-    handle = builder().GetTupleElement(builder().Tuple({handle}), 0);
-
-    // Ensure that the retval is returned even if another computation
-    // was mistakenly placed on the ComputationBuilder.
-    TF_CHECK_OK(builder().SetReturnValue(handle));
-  } else if (retval_.size() > 1) {
-    // There is at least one data-dependent expression: combine them
-    // into a Tuple in index order before compiling.
-    VLOG(1) << "Making the retval tuple.";
-    std::sort(retval_.begin(), retval_.end(),
-              [](const std::pair<int, xla::ComputationDataHandle>& a,
-                 const std::pair<int, xla::ComputationDataHandle>& b) {
-                return a.first < b.first;
-              });
-    std::vector<xla::ComputationDataHandle> elems;
-    elems.reserve(retval_.size());
-    for (const std::pair<int, xla::ComputationDataHandle>& r : retval_) {
-      elems.push_back(r.second);
-    }
-    // Make a tuple from the vector of handles.
-    handle = builder().Tuple(elems);
-  }
-
-  if (handle.handle() > 0) {
-    // Builds the XLA computation.
-    xla::StatusOr<xla::Computation> computation_status = builder().Build();
-    if (!computation_status.ok()) {
-      return computation_status.status();
-    }
-    *computation = computation_status.ConsumeValueOrDie();
-  }
-
-  // Make sure the compile time constants are in RetVal index order.
-  std::sort(compile_time_constant_.begin(), compile_time_constant_.end(),
-            [](const ConstRetVal& a, const ConstRetVal& b) {
-              return a.index < b.index;
-            });
-
-  // Fill in the result details and return.
-  *compile_time_constants = std::move(compile_time_constant_);
-  *requires_runtime_context = has_context_parameter_;
-  *num_nonconst_outputs = retval_.size();
-  return Status::OK();
-}
-
-XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
-                       const string& computation_name,
+XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
                        bool allow_cpu_custom_calls,
                        bool resolve_compile_time_constants)
     : compiler_(compiler),
-      xla_builder_(client, computation_name),
+      builder_(builder),
       allow_cpu_custom_calls_(allow_cpu_custom_calls),
       resolve_compile_time_constants_(resolve_compile_time_constants) {}
 
 const xla::ComputationDataHandle&
 XlaContext::GetOrCreateRuntimeContextParameter() {
-  mutex_lock lock(mu_);
   CHECK(allow_cpu_custom_calls_);
-  CHECK(!use_tuple_arg_);
   if (has_context_parameter_) return context_parameter_;
   has_context_parameter_ = true;
-  context_parameter_ = xla_builder_.Parameter(
-      num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
+
+  // Allocate the next available parameter for the context parameter.
+  int num_parameters = 0;
+  for (const Argument& arg : args_) {
+    if (!arg.value.is_constant) {
+      ++num_parameters;
+    }
+  }
+  context_parameter_ = builder_->Parameter(
+      num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
   return context_parameter_;
 }
 
@@ -214,72 +92,61 @@ string XlaContext::DebugString() { return "XLA JIT context"; }
 
 // This is called by the Retval Op to associate a computed value
 // with a specific return value of the subgraph.
-void XlaContext::AddRetval(int retval_index,
+void XlaContext::AddRetval(int retval_index, DataType type,
                            const xla::ComputationDataHandle& handle) {
   VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
-  // Add the return value to the list being built up. The executor
-  // is multi-threaded so this has to happen under the
-  // lock.
-  mutex_lock l(mu_);
-  retval_.emplace_back(retval_index, handle);
+  // Add the return value to the list being built up.
+  if (retvals_.size() <= retval_index) {
+    retvals_.resize(retval_index + 1);
+  }
+  retvals_[retval_index].is_constant = false;
+  retvals_[retval_index].type = type;
+  retvals_[retval_index].handle = handle;
 }
 
 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
                                   const xla::Literal& literal) {
   VLOG(1) << "Adding retval index " << retval_index
           << " with non-data-dependent tensor to XLA computation";
+  if (retvals_.size() <= retval_index) {
+    retvals_.resize(retval_index + 1);
+  }
+  retvals_[retval_index].type = dtype;
   if (resolve_compile_time_constants_) {
-    ConstRetVal value;
-    value.index = retval_index;
-    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
-    mutex_lock l(mu_);
-    compile_time_constant_.push_back(std::move(value));
+    retvals_[retval_index].is_constant = true;
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(
+        literal, dtype, &retvals_[retval_index].constant_value));
   } else {
-    mutex_lock l(mu_);
-    retval_.emplace_back(retval_index, xla_builder_.ConstantLiteral(literal));
+    retvals_[retval_index].is_constant = false;
+    retvals_[retval_index].handle = builder_->ConstantLiteral(literal);
   }
   return Status::OK();
 }
 
 void XlaContext::AddSideEffects() {
-  mutex_lock lock(mu_);
   has_side_effects_ = true;
 }
 
-/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
-  CHECK_NE(expression->handle().handle(), 0);
-  VLOG(1) << "Fetched T" << expression->handle().handle();
-  return expression;
-}
+xla::ComputationBuilder* XlaContext::builder() { return builder_; }
 
-/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor(
-    Tensor* tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
-  CHECK_EQ(expression->handle().handle(), 0);
-  return const_cast<XlaExpression*>(expression);
+Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
+                                  const xla::ComputationDataHandle& handle,
+                                  XlaVariable** variable) {
+  variables_.emplace_back(new XlaVariable);
+  *variable = variables_.back().get();
+  XlaVariable& var = **variable;
+  var.arg_num = arg_num;
+  var.name = std::move(name);
+  var.type = type;
+  var.initial_value = var.value = handle;
+  return Status::OK();
 }
 
-/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor(
-    const Tensor& tensor) {
-  return CastExpressionFromTensor(tensor);
-}
-
-/* static */ const xla::ComputationDataHandle&
-XlaContext::GetComputationFromTensor(const Tensor& tensor) {
-  return CastExpressionFromTensor(tensor)->handle();
-}
-
-xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; }
-
 const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
   return LookupOrCreate(type, &max_func_, [this, type] {
     const string type_string = DataTypeString(type);
     VLOG(1) << "Building Max() for " << type_string;
-    xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">");
+    xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
     xla::PrimitiveType xla_type;
     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -293,7 +160,7 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
   return LookupOrCreate(type, &add_func_, [this, type] {
     const string type_string = DataTypeString(type);
     VLOG(1) << "Building Add() for " << type_string;
-    xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">");
+    xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
     xla::PrimitiveType xla_type;
     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
@@ -307,14 +174,19 @@ const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
   return LookupOrCreate(type, &sigmoid_func_, [this, type] {
     const string type_string = DataTypeString(type);
     VLOG(1) << "Building Sigmoid() for " << type_string;
-    xla::ComputationBuilder b(builder().client(),
+    xla::ComputationBuilder b(builder()->client(),
                               "sigmoid<" + type_string + ">");
     xla::PrimitiveType xla_type;
     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
-    auto one = b.ConstantLiteral(xla::LiteralUtil::One(xla_type));
-    auto minus_one = b.Neg(one);
-    b.Div(one, b.Add(b.Exp(b.Mul(x, minus_one)), one));
+    // Clamp the inputs to the range [-18, 18] since anything outside
+    // this range is 0.0f or 1.0f in single-precision. We must clamp the range
+    // of x to avoid incorrect outputs due to fast-math optimizations for large
+    // negative x.
+    x = b.Clamp(XlaHelpers::IntegerLiteral(&b, type, -18), x,
+                XlaHelpers::IntegerLiteral(&b, type, 18));
+    auto one = XlaHelpers::One(&b, type);
+    b.Div(one, b.Add(b.Exp(b.Neg(x)), one));
     return b.Build().ConsumeValueOrDie();
   });
 }
@@ -323,7 +195,6 @@ const xla::Computation* XlaContext::LookupOrCreate(
     DataType type, ComputationMap* out,
     const std::function<xla::Computation()>& create) {
   {
-    mutex_lock l(mu_);
     const auto& entry = (*out)[type];
     if (!entry.IsNull()) {
       return &entry;
@@ -331,7 +202,6 @@ const xla::Computation* XlaContext::LookupOrCreate(
   }
   auto new_entry = create();
   {
-    mutex_lock l(mu_);
     // Somebody else might have made one concurrently.
     auto& entry = (*out)[type];
     if (entry.IsNull()) {
@@ -341,4 +211,4 @@ const xla::Computation* XlaContext::LookupOrCreate(
   }
 }
 
-}  // end namespace tensorflow
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 8ece3d37984..3978baaf637 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -13,178 +13,109 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// This file defines the contexts used to represent XLA JIT computatations.
+// This file defines the contexts used during XLA compilation.
 
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
 
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
 
 namespace tensorflow {
 
-// A XlaExpression wraps an XLA computation. Each Tensor sent
-// along an edge during XLA JIT compilation represents a
-// XlaExpression, and the shape of the Tensor matches the shape of
-// the subcomputation in the ComputationDataHandle. Each
-// expression is either a constant, an unbound parameter, or a
-// function of previously-compiled expressions.
-class XlaExpression {
- public:
-  XlaExpression();
+class XlaOpKernelContext;
 
-  // handle() stores the XLA handle of the computation that the
-  // expression represents.
-  void set_handle(const xla::ComputationDataHandle& h);
-  const xla::ComputationDataHandle& handle() const;
-
-  void set_constant_value(Tensor value);
-  bool has_constant_value() const { return has_constant_value_; }
-  const Tensor& constant_value() const { return constant_value_; }
-
- private:
-  friend class XlaContext;
-
-  // The XLA handle of the expression's computation.
-  xla::ComputationDataHandle handle_;
-
-  // If this expression is a constant with a known value, 'constant_value' is a
-  // host-memory Tensor containing the value. Used to avoid invoking XLA for
-  // expressions that are trivially constant.
-  bool has_constant_value_;
-  Tensor constant_value_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
-};
-
-// The XlaContext is the data structure accessible from
-// OpKernelContexts when evaluating a subgraph of Ops for JIT
-// compilation by XLA. When an Op is executed during JIT
-// compilation the input Tensors to the Op store handles to
-// subcomputations compiled by earlier Ops in the subgraph. The Op can
-// retrieve these subcomputations by calling either
-// GetExpressionFromTensor, which returns the XlaExpression holding
-// the subcomputation; or EvaluateAsConstant which returns an XLA
-// literal of the result of the subcomputation or an error status if
-// the subcomputation depends on unbound parameters. The Op may then
-// use the ComputationBuilder available from XlaContext::builder()
-// to compile one or more functions of the inputs into
-// ComputationDataHandles. The handles can be stored as new
-// expressions corresponding to the outputs of the Op by calling
-// CreateOutputTensorFromComputation or
-// CreateConstantOutputTensor. The *only* correct way to allocate an
-// output tensor is using one of the preceding two methods, since they
-// ensure there is a valid XlaExpression backing the output
-// tensor. No Op should ever call allocate_output or allocate_temp
-// directly on the OpKernelContext. It is permissible to pass a tensor
-// from an Op input to an output (e.g. call ctx->set_output with a
-// tensor passed as an input). As an example, the softmax Op produces
-// output from input as follows:
-//
-//    XlaContext& tc = XlaContext::Get(context);
-//    xla::ComputationBuilder& b = tc.builder();
-//    xla::ComputationDataHandle logits =
-//        tc.GetComputationFromTensor(logits_in));
-//    ... The softmax computation uses the builder b to compute a
-//        xla::ComputationDataHandle softmax holding the desired output.
-//    ...
-//    OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation(
-//                                context, 0, logits_in.shape().dim_sizes(),
-//                                softmax));
-//
+// The XlaContext is the data structure that holds the state of an XLA
+// compilation; it is accessible from OpKernelContexts when compiling a
+// subgraph of Ops using XLA.
 class XlaContext : public ResourceBase {
  public:
-  // If a retval can be evaluated at JIT time it is returned as a
-  // Literal in a ConstRetVal struct as part of the ComputationResult.
-  // TODO(misard) reconcile this with the duplicate data structure in
-  // the XlaCompilationCache class.
-  struct ConstRetVal {
-    // The index of the RetVal corresponding to this constant literal.
-    int index;
-    // If status is not OK, value's data is undefined.
-    Status status;
-    // The value of the RetVal evaluated at JIT compilation
-    // time. value.shape() always gives the correct shape of the
-    // RetVal. If !status.ok() then value's data is undefined, otherwise the
-    // Tensor buffer is allocated in CPU memory.
-    Tensor value;
+  // A struct that represents either a compile-time constant, or an XLA
+  // computation handle. Used to represent arguments and return values.
+  struct HandleOrConstant {
+    // Is this a compile-time constant? If so, what is its value?
+    bool is_constant;
+    Tensor constant_value;  // Must be in host memory.
+
+    // If this is not a constant, a computation handle. Since the mapping from
+    // Tensorflow types to XLA types is not necessarily injective (one-to-one),
+    // we also require the Tensorflow type.
+    DataType type;
+    xla::ComputationDataHandle handle;
   };
 
+  struct Argument {
+    // Descriptive name for the variable, for use in error messages.
+    string name;
+
+    // Is this a variable?
+    bool is_variable = false;
+
+    HandleOrConstant value;
+
+    int64 tensor_array_size = -1;
+  };
+
+  // Retrieves the XlaContext of the current compilation.
+  static XlaContext& Get(const OpKernelContext* ctx);
+  static XlaContext& Get(const XlaOpKernelContext* ctx);
+
+  // Creates a new XlaContext.
+  XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
+             bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
+
   // Virtual method defined by ResourceBase.
   string DebugString() override;
 
-  // Retrieve the XlaContext corresponding to a step's JIT compilation.
-  static XlaContext& Get(const OpKernelContext* ctx);
-  static XlaContext& Get(const XlaOpKernelContext* ctx) {
-    return Get(ctx->op_kernel_context());
-  }
-
-  // Create a new XlaContext.
-  XlaContext(XlaCompiler* compiler, xla::Client* client,
-             const string& computation_name, bool allow_cpu_custom_calls,
-             bool resolve_compile_time_constants);
-
-  // Builds XLA computations for each of the arguments.
-  // Should only be called once to initialize the arguments. Not thread-safe.
-  Status BuildArguments(std::vector<XlaCompiler::Argument> arguments,
-                        bool use_tuple_arg) TF_MUST_USE_RESULT;
-
-  // Returns the results of the symbolic computation that have accumulated in
-  // the XlaContext. After CollectResults() is called, the context is left in
-  // an invalid state and must not be reused.
-  // Sets `requires_runtime_context` if the emitted computation requires a
-  // runtime context argument. `compile_time_constants` describes any non
-  // data-dependent results of the computation. `num_nonconst_ouputs` is set to
-  // the number of outputs of the `computation`.
-  Status CollectResults(xla::Computation* computation,
-                        bool* requires_runtime_context,
-                        std::vector<ConstRetVal>* compile_time_constants,
-                        int* num_nonconst_outputs);
-
-  // This is called by the Retval Op to associate a computed value
-  // with a specific return value of the subgraph.
-  void AddRetval(int retval_index, const xla::ComputationDataHandle& handle);
-
-  // As for Retval, but for return values that are compile-time constants.
-  Status AddConstRetval(int retval_index, DataType dtype,
-                        const xla::Literal& literal);
-
-  // Mark the computation as having side effects (i.e., Send operators).
-  void AddSideEffects();
-
-  // Retrieves the ComputationDataHandle from an input Tensor to an Op. This
-  // computation was constructed by an Op that executed previously and
-  // created the output Tensor using CreateOutputTensorFromComputation
-  // or CreateConstantOutputTensor.
-  static const xla::ComputationDataHandle& GetComputationFromTensor(
-      const Tensor& tensor);
-
   XlaCompiler* compiler() const { return compiler_; }
 
   // Returns the ComputationBuilder that Ops use for compiling new
   // expressions.
-  xla::ComputationBuilder& builder();
+  xla::ComputationBuilder* builder();
 
-  const std::vector<XlaCompiler::Argument>& args() const { return args_; }
-  xla::ComputationDataHandle parameter(int num) { return parameters_[num]; }
+  bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
+  bool has_context_parameter() const { return has_context_parameter_; }
+
+  const std::vector<Argument>& args() const { return args_; }
+  void set_args(std::vector<Argument> args);
 
   // Get the runtime context parameter, adding one if it does not already
   // exist. Dies if not compiling a local executable.
   const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
 
-  bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
+  const std::vector<HandleOrConstant>& retvals() { return retvals_; }
+
+  // This is called by the Retval Op to associate a computed value
+  // with a specific return value of the subgraph.
+  void AddRetval(int retval_index, DataType type,
+                 const xla::ComputationDataHandle& handle);
+
+  // As for Retval, but for return values that are compile-time constants.
+  Status AddConstRetval(int retval_index, DataType dtype,
+                        const xla::Literal& literal);
+
+  // Mark the computation as having side effects (e.g., Send operators).
+  void AddSideEffects();
+
+  bool has_side_effects() const { return has_side_effects_; }
+
+  // Creates a variable with argument number `arg_num` and initial type `type`
+  // and value `handle`. `name` is a descriptive name for use in error
+  // messages. Fails if the variable already exists.
+  Status CreateVariable(int arg_num, string name, DataType type,
+                        const xla::ComputationDataHandle& handle,
+                        XlaVariable** variable);
+
+  const std::vector<std::unique_ptr<XlaVariable>>& variables() {
+    return variables_;
+  }
 
   // Get an XLA lambda to compute Max. This is cached in the
   // XlaContext since it may be used by multiple Ops. There is a
@@ -205,41 +136,11 @@ class XlaContext : public ResourceBase {
   static const char kXlaContextResourceName[];
 
  private:
-  friend class XlaOpKernelContext;
-
-  // This method is used to retrieve an expression that was allocated by
-  // a previous Op.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
-  // This method is used to retrieve an uninitialized expression from a
-  // newly-allocated tensor.
-  static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor);
-
-  // Retrieves the expression from an input Tensor to an Op. This
-  // expression was constructed by an Op that executed previously and
-  // created the output Tensor using CreateOutputTensorFromComputation
-  // or CreateConstantOutputTensor.
-  static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
-
   XlaCompiler* const compiler_;
 
-  mutable mutex mu_;
-
   // The ComputationBuilder used to construct the subgraph's compiled
   // representation.
-  xla::ComputationBuilder xla_builder_ GUARDED_BY(mu_);
-
-  // Number of XLA Parameters, not counting the context parameter, if any.
-  int num_parameters_;
-
-  // Arguments to the JIT compilation, both compile-time constant arguments and
-  // runtime parameters.
-  std::vector<XlaCompiler::Argument> args_;
-  bool use_tuple_arg_ = false;
-
-  // Runtime parameters to the XLA computation. Does not include
-  // compile-time constant arguments.
-  std::vector<xla::ComputationDataHandle> parameters_;
+  xla::ComputationBuilder* builder_;
 
   // Allow ops to emit CustomCall operations for CPU.
   const bool allow_cpu_custom_calls_;
@@ -252,18 +153,21 @@ class XlaContext : public ResourceBase {
   // for an additional final parameter to the computation, through which will be
   // passed a XlaLocalRuntimeContext* at runtime. Created on demand by
   // GetOrCreateRuntimeContextParameter().
-  bool has_context_parameter_ GUARDED_BY(mu_) = false;
-  xla::ComputationDataHandle context_parameter_ GUARDED_BY(mu_);
+  bool has_context_parameter_ = false;
+  xla::ComputationDataHandle context_parameter_;
 
-  // The data-dependent return values of the computation.
-  std::vector<std::pair<int, xla::ComputationDataHandle>> retval_
-      GUARDED_BY(mu_);
+  // Arguments to the Tensorflow graph, indexed by _Arg index.
+  // Includes both compile-time constant arguments and runtime parameters.
+  std::vector<Argument> args_;
 
-  // The non-data-dependent return values of the computation.
-  std::vector<ConstRetVal> compile_time_constant_ GUARDED_BY(mu_);
+  // Return values of the Tensorflow graph, indexed by _Retval index.
+  std::vector<HandleOrConstant> retvals_;
 
   // Does the computation have side effects, i.e., Send() calls?
-  bool has_side_effects_ GUARDED_BY(mu_) = false;
+  bool has_side_effects_ = false;
+
+  // Holds ownership of variables. The variables are not ordered.
+  std::vector<std::unique_ptr<XlaVariable>> variables_;
 
   // Cache of prebuilt computations indexed by their type.
   using ComputationMap = std::map<DataType, xla::Computation>;
@@ -273,16 +177,16 @@ class XlaContext : public ResourceBase {
   // map. The returned value != nullptr and is owned by the map.
   const xla::Computation* LookupOrCreate(
       DataType type, ComputationMap* out,
-      const std::function<xla::Computation()>& create) LOCKS_EXCLUDED(mu_);
+      const std::function<xla::Computation()>& create);
 
   // Cached computation to compute Max of two elements, specialized by type.
-  ComputationMap max_func_ GUARDED_BY(mu_);
+  ComputationMap max_func_;
 
   // Cached computation to compute Sum of two elements, specialized by type.
-  ComputationMap add_func_ GUARDED_BY(mu_);
+  ComputationMap add_func_;
 
   // Cached computation to compute Sigmoid of an element, specialized by type.
-  ComputationMap sigmoid_func_ GUARDED_BY(mu_);
+  ComputationMap sigmoid_func_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
 };
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index efb0facf7b8..f060f8f2f17 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 
@@ -89,7 +90,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
     case xla::U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case xla::F16:
-      LOG(FATAL) << "f16 literals not yet implemented";
+      literal = *xla::LiteralUtil::CreateR0<xla::half>(
+          static_cast<xla::half>(value));
+      break;
     case xla::TUPLE:
       LOG(FATAL) << "tuple element type is not integral";
     case xla::OPAQUE:
@@ -107,6 +110,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
   xla::PrimitiveType type;
   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
   switch (type) {
+    case xla::F16:
+      return b->ConstantR0<xla::half>(static_cast<xla::half>(value));
+      break;
     case xla::F32:
       return b->ConstantR0<float>(static_cast<float>(value));
       break;
@@ -139,4 +145,64 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
   return Status::OK();
 }
 
+template <typename T>
+static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
+  Tensor linspace(DataTypeToEnum<T>::v(), shape);
+  auto linspace_flat = linspace.flat<T>();
+  for (int64 i = 0; i < depth; ++i) {
+    linspace_flat(i) = i;
+  }
+  return linspace;
+}
+
+Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
+                          int axis, DataType index_type,
+                          const TensorShape& indices_shape,
+                          const xla::ComputationDataHandle& indices,
+                          const xla::ComputationDataHandle& on_value,
+                          const xla::ComputationDataHandle& off_value,
+                          xla::ComputationDataHandle* one_hot) {
+  const int indices_dims = indices_shape.dims();
+  const int output_dims = indices_dims + 1;
+
+  TensorShape output_shape = indices_shape;
+  output_shape.InsertDim(axis, depth);
+
+  // Build a Tensor populated with values 0, 1, 2, ... depth.
+  std::vector<int64> linspace_dims(output_dims, 1);
+  linspace_dims[axis] = depth;
+  TensorShape linspace_shape(linspace_dims);
+  Tensor linspace;
+  switch (index_type) {
+    case DT_UINT8:
+      linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
+      break;
+    case DT_INT32:
+      linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
+      break;
+    case DT_INT64:
+      linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
+      break;
+    default:
+      return errors::InvalidArgument("Invalid argument type ",
+                                     DataTypeString(index_type));
+  }
+  xla::Literal linspace_literal;
+  TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+
+  // Broadcast the linspace constant across the indices along the new axis,
+  // and test equality at each position.
+  std::vector<int64> broadcast_dims(indices_shape.dims());
+  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
+  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
+  xla::ComputationDataHandle one_hot_bool = builder->Eq(
+      indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
+
+  // Selects the user-provided off_value and on_value values.
+  *one_hot = builder->Select(
+      one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()),
+      builder->Broadcast(off_value, output_shape.dim_sizes()));
+  return Status::OK();
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 353ed02edda..a141ee05c13 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -66,6 +66,17 @@ class XlaHelpers {
   static Status ReshapeLiteral(const xla::Literal& input,
                                gtl::ArraySlice<int64> shape,
                                xla::Literal* output);
+
+  // Converts `indices` into a one-hot representation. `depth` is the size
+  // of the new axis to add. `axis` is the position at which to add the new
+  // axis. `indices_shape` is the shape of `indices`. `on_value` and
+  // `off_value` represent the values to use for the on and off positions,
+  // respectively.
+  static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis,
+                       DataType index_type, const TensorShape& indices_shape,
+                       const xla::ComputationDataHandle& indices,
+                       const xla::ComputationDataHandle& on_value,
+                       const xla::ComputationDataHandle& off_value,
+                       xla::ComputationDataHandle* one_hot);
 };
 
 }  // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h
index cd773d64ed4..dca420d6ee3 100644
--- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h
+++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h
@@ -23,7 +23,7 @@ limitations under the License.
 // actually used. E.g. some ahead-of-time compiled computations don't need a
 // thread pool.
 namespace Eigen {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
 }
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 00cf1adc119..3272b1efa15 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -31,11 +31,38 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
 }
 
 xla::ComputationBuilder* XlaOpKernelContext::builder() const {
-  return &XlaContext::Get(this).builder();
+  return XlaContext::Get(this).builder();
+}
+
+// Retrieves an XlaExpression that was allocated by a previous Op.
+static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
+  CHECK(expression->handle().handle() != 0 ||
+        expression->variable() != nullptr);
+  VLOG(1) << "Fetched T" << expression->handle().handle();
+  return expression;
+}
+
+// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
+static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
+  CHECK_EQ(expression->handle().handle(), 0);
+  return const_cast<XlaExpression*>(expression);
+}
+
+// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
+// computation was constructed by an Op that executed previously and
+// created the output Tensor using CreateOutputTensorFromComputation
+// or CreateConstantOutputTensor.
+static const xla::ComputationDataHandle& GetComputationFromTensor(
+    const Tensor& tensor) {
+  return CastExpressionFromTensor(tensor)->handle();
 }
 
 const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
-  return XlaContext::GetComputationFromTensor(context_->input(index));
+  return GetComputationFromTensor(context_->input(index));
 }
 
 TensorShape XlaOpKernelContext::InputShape(int index) {
@@ -60,8 +87,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
         " but was asked to be reshaped to incompatible shape ",
         new_shape.DebugString());
   }
-  const XlaExpression* expression =
-      XlaContext::CastExpressionFromTensor(tensor);
+  const XlaExpression* expression = CastExpressionFromTensor(tensor);
 
   // If the tensor has a known constant value, there is no need to invoke XLA.
   if (expression->has_constant_value()) {
@@ -112,6 +138,27 @@ Status XlaOpKernelContext::ConstantInputReshaped(
   return Status::OK();
 }
 
+// Converts an int32 or int64 scalar literal to an int64.
+static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
+  if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
+    return errors::InvalidArgument("value is not a scalar");
+  }
+  if (literal.shape().element_type() == xla::S32) {
+    *out = xla::LiteralUtil::Get<int32>(literal, {});
+  } else if (literal.shape().element_type() == xla::S64) {
+    *out = xla::LiteralUtil::Get<int64>(literal, {});
+  } else {
+    return errors::InvalidArgument("value must be either int32 or int64");
+  }
+  return Status::OK();
+}
+
+Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  return LiteralToInt64Scalar(literal, out);
+}
+
 // Converts an int32 or int64 1D literal to an int64 vector.
 static Status LiteralToInt64Vector(const xla::Literal& literal,
                                    std::vector<int64>* out) {
@@ -140,6 +187,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
   return LiteralToInt64Vector(literal, out);
 }
 
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
+                                                       xla::Literal* out) {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  switch (literal.shape().element_type()) {
+    case xla::S32:
+      out->Clear();
+      *out->mutable_shape() = literal.shape();
+      out->mutable_shape()->set_element_type(xla::S64);
+      for (int32 x : literal.s32s()) {
+        out->add_s64s(x);
+      }
+      return Status::OK();
+
+    case xla::S64:
+      out->Swap(&literal);
+      return Status::OK();
+
+    default:
+      return errors::InvalidArgument(
+          "Invalid argument to ConstantInputAsInt64Literal: ",
+          xla::ShapeUtil::HumanString(literal.shape()));
+  }
+}
+
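A sketch of how an op kernel might use the constant-input helpers above (illustrative only; `ExampleConstantInputsOp` and its input layout are hypothetical):

```c++
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

namespace tensorflow {

// Hypothetical kernel fragment: inputs 1 and 2 must be resolvable to
// compile-time constants (e.g., produced by Const nodes).
class ExampleConstantInputsOp : public XlaOpKernel {
 public:
  explicit ExampleConstantInputsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // A 1D int32/int64 input interpreted as a shape.
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &shape));

    // A scalar int32/int64 input, via the helper added above.
    int64 axis;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &axis));

    // ... use `shape` and `axis` to build the computation ...
    ctx->SetOutput(0, ctx->Input(0));
  }
};

}  // namespace tensorflow
```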
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { @@ -159,7 +231,7 @@ Status XlaOpKernelContext::InputList( handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(XlaContext::GetComputationFromTensor(input)); + handles->push_back(GetComputationFromTensor(input)); shapes->push_back(input.shape()); } return Status::OK(); @@ -176,6 +248,49 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } +Status XlaOpKernelContext::ReadVariableInput( + int index, xla::ComputationDataHandle* value) { + const Tensor& tensor = context_->input(index); + const XlaExpression* expression = CastExpressionFromTensor(tensor); + XlaVariable* variable = expression->variable(); + TF_RET_CHECK(variable != nullptr); + if (variable->value.handle() == 0) { + return errors::InvalidArgument("Read of uninitialized variable ", + variable->name); + } + *value = variable->value; + return Status::OK(); +} + +string XlaOpKernelContext::VariableDebugString(int index) { + const Tensor& tensor = context_->input(index); + const XlaExpression* expression = CastExpressionFromTensor(tensor); + XlaVariable* variable = expression->variable(); + if (!variable) { + return ""; + } + return variable->name; +} + +Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, + TensorShape* shape) const { + const Tensor& tensor = context_->input(index); + const XlaExpression* expression = CastExpressionFromTensor(tensor); + XlaVariable* variable = expression->variable(); + TF_RET_CHECK(variable != nullptr); + if (variable->value.handle() == 0) { + return errors::InvalidArgument("Read of uninitialized variable ", + variable->name); + } + *type = variable->type; + auto shape_or_status = builder()->GetShape(variable->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::ComputationDataHandle& handle) { // Makes the host Tensor that will refer to the expression. @@ -196,8 +311,7 @@ void XlaOpKernelContext::SetOutput(int index, // The expression is stored in the tensor's data buffer. Fill in the // fields now. - XlaExpression* expression = - XlaContext::CastExpressionFromUninitializedTensor(output); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); expression->set_handle(handle); } @@ -207,6 +321,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { xla::Literal literal; OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); + CHECK_NE(handle.handle(), 0); // Make the Tensor that will refer to the expression. Tensor* output = nullptr; @@ -217,16 +332,57 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { // The expression is stored in the tensor's data buffer. Fill in the // fields now. - XlaExpression* expression = - XlaContext::CastExpressionFromUninitializedTensor(output); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); expression->set_handle(handle); expression->set_constant_value(constant); } +void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) { + Tensor* output = nullptr; + // The shape of the output tensor is the shape of the variable resource + // (i.e., a scalar), not the shape of the variable's value. 
+  OP_REQUIRES_OK(context_,
+                 context_->allocate_output(index, TensorShape(), &output));
+  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
+  expression->set_variable(variable);
+}
+
+Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
+  const XlaExpression* expression =
+      CastExpressionFromTensor(context_->input(index));
+  TF_RET_CHECK(expression->variable() != nullptr);
+  *variable = expression->variable();
+  return Status::OK();
+}
+
+Status XlaOpKernelContext::AssignVariable(
+    int index, DataType type, const xla::ComputationDataHandle& handle) {
+  TF_RET_CHECK(handle.handle() != 0);
+  SetOpHasSideEffects();
+
+  const XlaExpression* expression =
+      CastExpressionFromTensor(context_->input(index));
+  XlaVariable* variable = expression->variable();
+  TF_RET_CHECK(variable != nullptr);
+  if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
+        (variable->type == type))) {
+    return errors::InvalidArgument(
+        "Types of variables cannot change after initialization: old type was ",
+        DataTypeString(variable->type), ", new type is ", DataTypeString(type));
+  }
+  variable->type = type;
+  variable->value = handle;
+  return Status::OK();
+}
+
 void XlaOpKernelContext::SetOpHasSideEffects() {
   XlaContext::Get(context_).AddSideEffects();
 }
 
+XlaCompiler* XlaOpKernelContext::compiler() const {
+  return XlaContext::Get(context_).compiler();
+}
+
 void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
 
 void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
   context_->CtxFailureWithWarning(s);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 5fbc0cb6ac3..a25774c3a6a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/platform/macros.h"
@@ -45,9 +46,14 @@ class XlaOpKernel : public OpKernel {
 // XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
 // implementing operators that perform symbolic execution as part of the XLA
 // compiler. The key difference is that XlaOpKernelContext produces and consumes
-// data as XLA computations, rather than as standard Tensors. (Under the hood,
-// symbolic execution communicates using special Tensors, but that is an
-// implementation detail that this class hides.)
+// data as XLA computations, rather than as standard Tensors.
+//
+// Under the hood, symbolic execution communicates using special Tensors that
+// wrap XlaExpression objects; however, this is an implementation detail that
+// this class hides. The *only* correct way to allocate a Tensor during
+// compilation is using the XlaOpKernelContext methods, since they ensure there
+// is a valid XlaExpression backing the tensor. No Op should ever call
+// allocate_output or allocate_temp directly on the underlying OpKernelContext.
 class XlaOpKernelContext {
  public:
   explicit XlaOpKernelContext(OpKernelContext* context);
@@ -98,9 +104,15 @@ class XlaOpKernelContext {
   Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
                                xla::Literal* constant_literal);
 
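Pausing between the implementation above and the header declarations that follow: a sketch of the read-modify-write pattern the new variable methods enable. The op class `VariableAddOp` and its `"T"` attribute wiring are hypothetical; `ReadVariableInput`, `AssignVariable`, and `builder()` are from this change.

```c++
// Hypothetical kernel: var += delta, entirely symbolic.
class VariableAddOp : public XlaOpKernel {
 public:
  explicit VariableAddOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Input 0 is the DT_RESOURCE handle. ReadVariableInput fails if the
    // variable has never been assigned (its value handle is still 0).
    xla::ComputationDataHandle value;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &value));

    // Build value + delta in the XLA computation being constructed.
    xla::ComputationDataHandle new_value =
        ctx->builder()->Add(value, ctx->Input(1));

    // AssignVariable marks the op as side-effecting and rejects any attempt
    // to change the variable's dtype after initialization.
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, new_value));
  }

 private:
  DataType dtype_;
};
```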
+  // Converts a constant scalar int32 or int64 tensor into an int64.
+  Status ConstantInputAsIntScalar(int index, int64* out);
+
   // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
   Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
 
+  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
+  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+
   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
   Status ConstantInputAsShape(int index, TensorShape* shape);
 
@@ -134,6 +146,32 @@ class XlaOpKernelContext {
   // Mark the op as having side effects (i.e., via Send).
   void SetOpHasSideEffects();
 
+  // Variables
+
+  // Sets `*type` and `*shape` to the current type and shape of a variable's
+  // value.
+  Status GetVariableTypeAndShape(int index, DataType* type,
+                                 TensorShape* shape) const;
+
+  // Reads the current value of the resource variable referred to by input
+  // 'index'.
+  Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
+
+  // Assigns the value `handle` to the variable referenced by input
+  // `variable_index`. Marks the operator as having side effects.
+  Status AssignVariable(int variable_index, DataType type,
+                        const xla::ComputationDataHandle& handle);
+
+  // Sets '*variable' to the variable associated with input `index`.
+  Status GetVariableInput(int index, XlaVariable** variable);
+
+  // Sets output 'index' to be a reference to variable 'variable'. Used
+  // to propagate resource variables through the compilation.
+  void SetVariableOutput(int index, XlaVariable* variable);
+
+  // Returns a human-readable debug string describing 'variable_index'.
+  string VariableDebugString(int variable_index);
+
   // Helper routines for the OP_REQUIRES macros
   void CtxFailure(Status s);
   void CtxFailureWithWarning(Status s);
@@ -151,6 +189,10 @@ class XlaOpKernelContext {
   // Returns the underlying OpKernelContext. Use rarely.
   OpKernelContext* op_kernel_context() const { return context_; }
 
+  // Returns the XlaCompiler that is performing the compilation. Used for,
+  // e.g., While to compile nested computations.
+  XlaCompiler* compiler() const;
+
   // TODO(phawkins): find a better home for these helpers.
 
   // Get an XLA lambda to compute Max. This is cached in the
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
new file mode 100644
index 00000000000..1bb0d852899
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -0,0 +1,311 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace tensorflow { + +const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT"; +const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; +const char* const DEVICE_XLA_CPU = "XLA_CPU"; +const char* const DEVICE_XLA_GPU = "XLA_GPU"; + +// Is platform 'id' supported by XLA? +static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { + auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); + if (!platform.ok()) return false; + return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +} + +XlaOpRegistry::XlaOpRegistry() = default; +XlaOpRegistry::~XlaOpRegistry() = default; + +/* static */ void XlaOpRegistry::RegisterCompilationDevice( + const string& device_name, const DeviceRegistration& registration) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto result = + registry.compilation_devices_.emplace(device_name, registration); + CHECK(result.second || result.first->second.compilation_device_name == + registration.compilation_device_name); +} + +/* static */ void XlaOpRegistry::RegisterBackend( + const string& compilation_device_name, + gtl::ArraySlice supported_types, BackendOpFilter op_filter) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto result = registry.backends_.emplace(compilation_device_name, Backend()); + CHECK(result.second) << "Duplicate XLA backend registration " + << compilation_device_name; + result.first->second.supported_types.insert(supported_types.begin(), + supported_types.end()); + result.first->second.op_filter = op_filter; +} + +/* static */ bool XlaOpRegistry::GetCompilationDevice( + const string& device_name, const DeviceRegistration** registration) { + XlaOpRegistry& registry = Instance(); + + // Lazily register the CPU and GPU JIT devices the first time + // GetCompilationDevice is called. 
+  static void* registration_init = [&registry]() {
+    mutex_lock lock(registry.mutex_);
+    if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) {
+      DeviceRegistration& registration =
+          registry.compilation_devices_[DEVICE_CPU];
+      registration.compilation_device_name = DEVICE_CPU_XLA_JIT;
+      registration.requires_compilation = false;
+      registration.enable_jit_by_default = false;
+      registration.compile_resource_ops = false;
+    }
+    if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) {
+      DeviceRegistration& registration =
+          registry.compilation_devices_[DEVICE_GPU];
+      registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
+      registration.requires_compilation = false;
+      registration.enable_jit_by_default = true;
+      registration.compile_resource_ops = false;
+    }
+    return nullptr;
+  }();
+  (void)registration_init;
+
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.compilation_devices_.find(device_name);
+  if (it == registry.compilation_devices_.end()) return false;
+  *registration = &it->second;
+  return true;
+}
+
+void XlaOpRegistry::RegisterCompilationKernels() {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+
+  if (registry.jit_kernels_registered_) return;
+  registry.jit_kernels_registered_ = true;
+
+  OpRegistryInterface* op_registry = OpRegistry::Global();
+  for (const auto& op : registry.ops_) {
+    const OpDef* op_def;
+    TF_CHECK_OK(op_registry->LookUpOpDef(op.first, &op_def));
+
+    std::unordered_set<string> type_attrs;
+    for (const OpDef::AttrDef& attr_def : op_def->attr()) {
+      if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
+        type_attrs.insert(attr_def.name());
+      }
+    }
+
+    // Checks there are no type constraints referring to unknown attributes.
+    for (const auto& constraint : op.second->type_constraints) {
+      if (type_attrs.find(constraint.first) == type_attrs.end()) {
+        LOG(FATAL) << "Unknown type attribute " << constraint.first
+                   << " in XLA op registration for " << op.first;
+      }
+    }
+
+    for (auto& backend : registry.backends_) {
+      // If the operator has a device whitelist, only register on whitelisted
+      // devices.
+      if (op.second->has_device_whitelist &&
+          op.second->device_whitelist.find(backend.first) ==
+              op.second->device_whitelist.end()) {
+        continue;
+      }
+
+      std::unique_ptr<KernelDef> kdef(new KernelDef);
+      kdef->set_op(op.second->name);
+      kdef->set_device_type(backend.first);
+
+      // Constrain each type attribute to the intersection of:
+      // a) the types supported by the backend, and
+      // b) the attribute's type constraints.
+      // TODO(phawkins): it may be necessary to also take the intersection with
+      // the set of types supported by the OpDef.
+ for (const string& type_attr : type_attrs) { + KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); + attr_constraint->set_name(type_attr); + auto* allowed_values = + attr_constraint->mutable_allowed_values()->mutable_list(); + + auto it = op.second->type_constraints.find(type_attr); + for (DataType dtype : backend.second.supported_types) { + if (it == op.second->type_constraints.end() || + (it != op.second->type_constraints.end() && + it->second.find(dtype) != it->second.end())) { + allowed_values->add_type(dtype); + } + } + if (op.second->allow_resource_types) { + allowed_values->add_type(DT_RESOURCE); + } + } + if (backend.second.op_filter != nullptr && + !backend.second.op_filter(kdef.get())) { + continue; + } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op.first; + registry.kernel_registrars_.emplace_back( + new kernel_factory::OpKernelRegistrar( + new KernelDef(*kdef), "XlaJitOp", op.second->factory)); + backend.second.kernel_defs.push_back(std::move(kdef)); + } + } +} + +std::vector XlaOpRegistry::DeviceKernels( + const string& compilation_device_name) { + std::vector kernels; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.backends_.find(compilation_device_name); + CHECK(it != registry.backends_.end()) + << "Unknown backend " << compilation_device_name; + for (const std::unique_ptr& k : it->second.kernel_defs) { + if (!registry.ops_.at(k->op())->compilation_only) { + kernels.push_back(k.get()); + } + } + return kernels; +} + +XlaOpRegistry& XlaOpRegistry::Instance() { + static XlaOpRegistry* r = new XlaOpRegistry; + return *r; +} + +XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { + registration_.reset(new XlaOpRegistry::OpRegistration); + registration_->name = name.ToString(); +} + +XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { + XlaOpRegistrationBuilder registration(name); + return registration; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( + gtl::ArraySlice devices) { + registration_->has_device_whitelist = true; + for (StringPiece device : devices) { + registration_->device_whitelist.insert(device.ToString()); + } + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { + registration_->has_device_whitelist = true; + registration_->device_whitelist.insert(device.ToString()); + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompilationOnly() { + registration_->compilation_only = true; + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { + registration_->allow_resource_types = true; + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( + StringPiece attr_name, DataType allowed) { + std::set& types = + registration_->type_constraints[attr_name.ToString()]; + types.insert(allowed); + return *this; +} + +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( + StringPiece attr_name, gtl::ArraySlice allowed) { + std::set& types = + registration_->type_constraints[attr_name.ToString()]; + for (DataType t : allowed) { + types.insert(t); + } + return *this; +} + +std::unique_ptr XlaOpRegistrationBuilder::Build( + XlaOpRegistry::Factory factory) { + registration_->factory = factory; + return std::move(registration_); +} + +XlaOpRegistrar::XlaOpRegistrar( + std::unique_ptr registration) { + XlaOpRegistry& registry = XlaOpRegistry::Instance(); + mutex_lock 
lock(registry.mutex_);
+  auto result = registry.ops_.emplace(registration->name, nullptr);
+  if (!result.second) {
+    LOG(FATAL) << "Duplicate XLA op registration " << registration->name;
+  }
+  result.first->second = std::move(registration);
+}
+
+XlaBackendRegistrar::XlaBackendRegistrar(
+    StringPiece name, gtl::ArraySlice<DataType> types,
+    XlaOpRegistry::BackendOpFilter op_filter) {
+  XlaOpRegistry& registry = XlaOpRegistry::Instance();
+  registry.RegisterBackend(name.ToString(), types, op_filter);
+}
+
+bool CpuOpFilter(KernelDef* kdef) {
+  // TODO(b/34339814): implement inverse erf for double types and remove this
+  // workaround.
+  if (kdef->op() == "RandomStandardNormal") {
+    kdef->clear_constraint();
+    // Change the type constraint to permit only DT_FLOAT.
+    KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
+    attr_constraint->set_name("dtype");
+    attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
+        DT_FLOAT);
+    return true;
+  }
+  return true;
+}
+
+REGISTER_XLA_BACKEND(DEVICE_CPU_XLA_JIT, kCpuAllTypes, CpuOpFilter);
+
+bool GpuOpFilter(KernelDef* kdef) {
+  // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
+  // slow code.
+  // TODO(b/34969189) The implementation of TruncatedNormal generates illegal
+  // code on GPU.
+  if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
+      kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
+    return false;
+  }
+  return true;
+}
+
+REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
new file mode 100644
index 00000000000..9a39cc96754
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -0,0 +1,270 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
+
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
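For orientation before the declarations: this header gives op authors a small registration surface. A hedged usage sketch follows, built from the macros and builder methods declared later in this file; the op name "MyOp" and the kernel class MyOpKernel are placeholders, not part of this change.

```c++
// Register a hypothetical kernel against the CPU JIT backend only, with a
// type constraint drawn from the kFloatTypes array defined below.
REGISTER_XLA_OP(Name("MyOp")
                    .Device(DEVICE_CPU_XLA_JIT)
                    .TypeConstraint("T", kFloatTypes),
                MyOpKernel);
```

A backend, by contrast, registers itself once with its supported types and an optional per-op filter, exactly as the CpuOpFilter/GpuOpFilter registrations in the .cc above do.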
+// Names of the XLA compilation devices. These are not user-visible, and are
+// used internally by the TensorFlow/XLA bridge to perform symbolic execution
+// of a TensorFlow graph.
+
+extern const char* const DEVICE_CPU_XLA_JIT;  // "XLA_CPU_JIT"
+extern const char* const DEVICE_GPU_XLA_JIT;  // "XLA_GPU_JIT"
+
+extern const char* const DEVICE_XLA_CPU;
+extern const char* const DEVICE_XLA_GPU;
+
+constexpr std::array<DataType, 2> kIntTypes = {{DT_INT32, DT_INT64}};
+constexpr std::array<DataType, 2> kFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
+constexpr std::array<DataType, 4> kNumericTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
+
+constexpr std::array<DataType, 5> kCpuAllTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+
+constexpr std::array<DataType, 5> kGpuAllTypes = {
+    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+
+// Class that manages registrations of operators and devices for the XLA JIT.
+// Not thread-safe.
+class XlaOpRegistry {
+ public:
+  typedef OpKernel* (*Factory)(OpKernelConstruction*);
+
+  // Describes how to compile operators assigned to a device.
+  struct DeviceRegistration {
+    // The name of the XLA compilation device to use to compile code.
+    string compilation_device_name;
+
+    // Do operators assigned to this device require compilation?
+    bool requires_compilation;
+
+    // If !requires_compilation, should we try to JIT operators on this device
+    // when XLA JIT compilation is enabled globally via the SessionOptions?
+    // (It is still possible to explicitly mark operators to JIT compile, even
+    // if enable_jit_by_default is false.)
+    bool enable_jit_by_default;
+
+    // Enable compilation of operators that use DT_RESOURCE types?
+    bool compile_resource_ops = false;
+  };
+
+  // Registers an XLA backend. `compilation_device_name` is the name of the
+  // device used for symbolic execution during compilation. `supported_types`
+  // is the list of non-resource types supported by the device. Each operator
+  // will be registered for the intersection of the operator's supported types
+  // and the device's supported types. `op_filter` is a function used to
+  // exclude or modify operator registrations on the device; it may be
+  // nullptr, in which case all ops are included.
+  // `op_filter` should return true if the op should be registered on
+  // the device; it may optionally modify the KernelDef.
+  typedef bool (*BackendOpFilter)(KernelDef* kdef);
+  static void RegisterBackend(const string& compilation_device_name,
+                              gtl::ArraySlice<DataType> supported_types,
+                              BackendOpFilter op_filter);
+
+  // Registers `device_name` for XLA compilation, using information from
+  // `registration`.
+  static void RegisterCompilationDevice(const string& device_name,
+                                        const DeviceRegistration& registration);
+
+  // Returns via `*registration` the registration information for
+  // `device_name`, if it has been registered for XLA compilation. Returns
+  // false and leaves `*registration` unchanged if no matching device is
+  // registered.
+  // `registration->enable_jit_by_default` is true if we should try to JIT
+  // operators on this device when the JIT is enabled via the Session
+  // OptimizerOptions.
+  static bool GetCompilationDevice(const string& device_name,
+                                   const DeviceRegistration** registration);
+
+  // Registers all JIT kernels on JIT devices, if not already registered.
+  // Does nothing otherwise.
+  static void RegisterCompilationKernels();
+
+  // Returns KernelDefs for compilation ops registered on
+  // 'compilation_device_name'.
+  // Does not include kernels registered as CompilationOnly.
+ static std::vector DeviceKernels( + const string& compilation_device_name); + + private: + friend class XlaBackendRegistrar; + friend class XlaOpRegistrar; + friend class XlaOpRegistrationBuilder; + + static XlaOpRegistry& Instance(); + + XlaOpRegistry(); + ~XlaOpRegistry(); + + mutex mutex_; + + // Describes an XLA backend. + struct Backend { + // Which types are supported by this device? + std::set supported_types; + + // The per-backend operator filter function. See the comment on + // RegisterBackend() for details. + BackendOpFilter op_filter; + + // KernelDefs built by RegisterCompilationKernels() for each op supported + // by the device. + std::vector> kernel_defs; + }; + + // Map from compilation device names to a description of the backend. + std::unordered_map backends_ GUARDED_BY(mutex_); + + // Map from Tensorflow device names to the corresponding JIT device metadata. + std::unordered_map compilation_devices_ + GUARDED_BY(mutex_); + + // A description of a Tensorflow operator that can be compiled to XLA. + struct OpRegistration { + string name; + + // Should this operator be registered only on compilation devices, without a + // dummy kernel registered on the corresponding XLA device? + bool compilation_only = false; + + // Should we allow resource types for type attributes? Used by _Arg to + // allow DT_RESOURCE. + bool allow_resource_types = false; + + // Mapping from attribute name to a list of supported types. + std::unordered_map> type_constraints; + + // An optional whitelist of devices. If there is no whitelist, all devices + // are permitted. + bool has_device_whitelist = false; + std::unordered_set device_whitelist; + + // Factory used to build OpKernels that perform symbolic execution. + Factory factory; + }; + + // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. + std::unordered_map> ops_ + GUARDED_BY(mutex_); + + // Have we already registered the JIT kernels on the JIT devices? + bool jit_kernels_registered_ = false; + + // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel + // registrations created by RegisterCompilationKernels() and + // RegisterDeviceKernels(). + std::vector> + kernel_registrars_ GUARDED_BY(mutex_); +}; + +// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example: +// REGISTER_XLA_OP(Name("Add"), AddOp); +// where 'AddOp' is the name of a JIT OpKernel class that implements "Add". +// +// We don't use a variadic macro here because we don't expect JIT operators to +// be templated. + +#define REGISTER_XLA_OP(NAME, OP) \ + REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) + +class XlaOpRegistrationBuilder { + public: + // Starts an operator registration chain. + static XlaOpRegistrationBuilder Name(StringPiece name); + + // Specifies a whitelist of devices on which the operator may run. + XlaOpRegistrationBuilder& Device(StringPiece devices); + XlaOpRegistrationBuilder& Device(gtl::ArraySlice devices); + + // Specifies a type constraint for a type variable attribute. Each constraint + // specifies the set of types that the type variable may assume. + XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + DataType allowed); + + XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + gtl::ArraySlice allowed); + + // Specifies that a dummy copy of this operator should not be registered on + // XLA_* devices, but may be used during compilation. + XlaOpRegistrationBuilder& CompilationOnly(); + + // Allow DT_RESOURCE types for type parameters. 
+ XlaOpRegistrationBuilder& AllowResourceTypes(); + + std::unique_ptr Build( + XlaOpRegistry::Factory factory); + + private: + XlaOpRegistrationBuilder(StringPiece name); + + std::unique_ptr registration_; +}; + +// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage: +// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter); +#define REGISTER_XLA_BACKEND(NAME, ...) \ + REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__) + +// Implementation details. + +class XlaOpRegistrar { + public: + XlaOpRegistrar(std::unique_ptr registration); +}; + +#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \ + REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP) + +#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ + static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ + XlaOpRegistrationBuilder::BUILDER.Build( \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { return new OP(context); })); + +class XlaBackendRegistrar { + public: + XlaBackendRegistrar(StringPiece name, gtl::ArraySlice types, + XlaOpRegistry::BackendOpFilter op_filter = nullptr); +}; + +#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \ + REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__) + +#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \ + static ::tensorflow::XlaBackendRegistrar \ + xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 0f2a46c11d3..2491cc3f7a2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -6,6 +6,8 @@ package_group( name = "friends", packages = [ "//tensorflow/compiler/...", + "//tensorflow/contrib/tpu/...", + "//tensorflow/contrib/xla_tf_graph/...", ], ) @@ -16,6 +18,7 @@ package_group( ], ) +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") # Filegroup used to collect source files for dependency checking. @@ -43,11 +46,43 @@ xla_proto_library( ], ) +# This is a headers target that extra XLA devices can use to prevent +# circular dependencies. Devices that are compiled as separate shared +# objects can also use it to prevent linking of library code. 
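Before the BUILD changes continue: to make the macro plumbing above concrete, here is roughly what one registration expands to. The names "MyOp" and MyOpKernel are illustrative, and `__COUNTER__` is shown fixed at 0.

```c++
// REGISTER_XLA_OP(Name("MyOp"), MyOpKernel); expands to approximately:
static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__0__object(
    XlaOpRegistrationBuilder::Name("MyOp").Build(
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* { return new MyOpKernel(context); }));
```

The captureless lambda decays to the `XlaOpRegistry::Factory` function pointer, and the static registrar object runs at program start, inserting the op into the registry's `ops_` map.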
+cc_header_only_library( + name = "xla_headers_lib", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_evaluator", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:stream_executor_headers_lib", + ], +) + +cc_library( + name = "test", + testonly = 1, + hdrs = ["test.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + ], +) + cc_library( name = "types", hdrs = ["types.h"], visibility = [":friends"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], ) cc_library( @@ -80,9 +115,9 @@ cc_test( deps = [ ":status_macros", ":statusor", + ":test", ":test_helpers", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -115,6 +150,7 @@ cc_test( srcs = ["statusor_test.cc"], deps = [ ":statusor", + ":test", ":types", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -148,18 +184,22 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":status_macros", + ":statusor", ":types", + ":util", "//tensorflow/core:lib", ], ) cc_test( name = "util_test", + size = "small", srcs = ["util_test.cc"], deps = [ + ":test", ":types", ":util", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -195,37 +235,40 @@ cc_library( cc_test( name = "shape_util_test", + size = "small", srcs = ["shape_util_test.cc"], deps = [ ":shape_util", + ":test", ":test_helpers", ":types", ":util", - "//tensorflow/core:test", + ":xla_data_proto", "//tensorflow/core:test_main", ], ) cc_test( name = "layout_util_test", + size = "small", srcs = ["layout_util_test.cc"], deps = [ ":shape_util", + ":test", ":test_helpers", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) cc_test( name = "index_util_test", + size = "small", srcs = ["index_util_test.cc"], deps = [ ":shape_util", - ":test_helpers", + ":test", ":xla_data_proto", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -240,6 +283,7 @@ cc_library( ":array3d", ":array4d", ":shape_util", + ":status_macros", ":types", ":util", ":xla_data_proto", @@ -249,13 +293,14 @@ cc_library( cc_test( name = "literal_util_test", + size = "small", srcs = ["literal_util_test.cc"], deps = [ ":array3d", ":array4d", ":literal_util", ":shape_util", - ":test_helpers", + ":test", ":types", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -270,7 +315,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":util", - ":xla_data_proto", "//tensorflow/core:lib", ], ) @@ -300,10 +344,11 @@ cc_library( cc_test( name = "array2d_test", + size = "small", srcs = ["array2d_test.cc"], deps = [ ":array2d", - "//tensorflow/core:test", + ":test", "//tensorflow/core:test_main", ], ) @@ -320,11 +365,12 @@ cc_library( cc_test( name = "array3d_test", + size = "small", srcs = ["array3d_test.cc"], deps = [ ":array3d", + ":test", ":types", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -342,11 +388,12 @@ cc_library( cc_test( name = "array4d_test", + size = "small", srcs = ["array4d_test.cc"], deps = [ ":array4d", + ":test", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -358,25 +405,6 @@ 
cc_library( visibility = ["//visibility:public"], ) -cc_library( - name = "differential_set", - hdrs = ["differential_set.h"], - visibility = [":internal"], - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_test( - name = "differential_set_test", - srcs = ["differential_set_test.cc"], - deps = [ - ":differential_set", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "packed_literal_reader", srcs = ["packed_literal_reader.cc"], @@ -397,7 +425,6 @@ cc_library( cc_library( name = "test_helpers", testonly = 1, - srcs = ["test_helpers.cc"], hdrs = ["test_helpers.h"], visibility = [":internal"], deps = [ @@ -429,15 +456,16 @@ cc_library( cc_test( name = "text_literal_reader_test", + size = "small", srcs = ["text_literal_reader_test.cc"], deps = [ ":literal_util", ":shape_util", + ":test", ":text_literal_reader", ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -459,14 +487,15 @@ cc_library( cc_test( name = "text_literal_writer_test", + size = "small", srcs = ["text_literal_writer_test.cc"], deps = [ ":literal_util", + ":test", ":test_helpers", ":text_literal_writer", ":types", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -486,12 +515,13 @@ cc_library( cc_test( name = "shape_tree_test", + size = "small", srcs = ["shape_tree_test.cc"], deps = [ ":shape_tree", ":shape_util", + ":test", ":xla_data_proto", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -543,17 +573,18 @@ cc_library( cc_test( name = "reference_util_test", + size = "small", srcs = ["reference_util_test.cc"], deps = [ ":array2d", ":array4d", ":literal_util", ":reference_util", + ":test", ":util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index ceed573f1f0..593084a0c11 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -44,12 +44,14 @@ class Array2D { Array2D() : n1_(0), n2_(0) {} // Creates an array of dimensions n1 x n2, uninitialized values. - Array2D(const int64 n1, const int64 n2) : n1_(n1), n2_(n2) { - values_.resize(n1 * n2); + Array2D(const int64 n1, const int64 n2) + : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { + Fill(T()); } // Creates an array of dimensions n1 x n2, initialized to value. 
- Array2D(const int64 n1, const int64 n2, const T value) : Array2D(n1, n2) { + Array2D(const int64 n1, const int64 n2, const T value) + : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { Fill(value); } @@ -67,16 +69,30 @@ class Array2D { } } - T& operator()(const int64 n1, const int64 n2) { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - return values_[n1 * n2_ + n2]; + Array2D(const Array2D& other) : Array2D(other.n1(), other.n2()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); } - const T& operator()(const int64 n1, const int64 n2) const { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - return values_[n1 * n2_ + n2]; + Array2D& operator=(const Array2D& other) { + n1_ = other.n1(); + n2_ = other.n2(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + T& operator()(const int64 i1, const int64 i2) { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + return values_[i1 * n2_ + i2]; + } + + const T& operator()(const int64 i1, const int64 i2) const { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + return values_[i1 * n2_ + i2]; } // Access to the array's dimensions. height() and width() provide the @@ -86,15 +102,15 @@ class Array2D { int64 n2() const { return n2_; } int64 height() const { return n1_; } int64 width() const { return n2_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return n1_ * n2_; } // Low-level accessor for stuff like memcmp, handle with care. Returns pointer // to the underlying storage of the array (similarly to std::vector::data()). - T* data() const { return const_cast(this)->values_.data(); } + T* data() const { return const_cast(this)->values_.get(); } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Applies f to all cells in this array, in row-major order. @@ -126,8 +142,8 @@ class Array2D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -152,7 +168,7 @@ class Array2D { private: int64 n1_; int64 n2_; - std::vector values_; + std::unique_ptr values_; }; // Returns a linspace-populated Array2D in the range [from, to] (inclusive) diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc index ac107b1c0d4..795d50ca5b5 100644 --- a/tensorflow/compiler/xla/array2d_test.cc +++ b/tensorflow/compiler/xla/array2d_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/core/platform/test.h" +#include "tensorflow/compiler/xla/test.h" namespace xla { namespace { @@ -84,6 +84,17 @@ TEST(Array2dTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(Array2dTest, IndexingReadWriteBool) { + Array2D arr = {{false, true, false}, {true, true, false}}; + + EXPECT_EQ(arr(1, 1), true); + EXPECT_EQ(arr(1, 2), false); + arr(1, 1) = false; + arr(1, 2) = true; + EXPECT_EQ(arr(1, 1), false); + EXPECT_EQ(arr(1, 2), true); +} + TEST(Array2dTest, Fill) { Array2D fullof7(2, 3, 7); for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index 46bc1a63921..124ccd1975b 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -20,9 +20,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -39,13 +39,13 @@ class Array3D { public: // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) - : n1_(n1), n2_(n2), n3_(n3) { - values_.resize(n1 * n2 * n3); + : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { + Fill(T()); } // Creates an array of dimensions n1 x n2 x n3, initialized to value. Array3D(const int64 n1, const int64 n2, const int64 n3, const T value) - : Array3D(n1, n2, n3) { + : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { Fill(value); } @@ -73,34 +73,50 @@ class Array3D { } } - T& operator()(const int64 n1, const int64 n2, const int64 n3) { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - CHECK_LT(n3, n3_); - return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + Array3D(const Array3D& other) + : Array3D(other.n1(), other.n2(), other.n3()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); } - const T& operator()(const int64 n1, const int64 n2, const int64 n3) const { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - CHECK_LT(n3, n3_); - return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + Array3D& operator=(const Array3D& other) { + n1_ = other.n1(); + n2_ = other.n2(); + n3_ = other.n3(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + T& operator()(const int64 i1, const int64 i2, const int64 i3) { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + CHECK_LT(i3, n3_); + return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; + } + + const T& operator()(const int64 i1, const int64 i2, const int64 i3) const { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + CHECK_LT(i3, n3_); + return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; } // Access to the array's dimensions. int64 n1() const { return n1_; } int64 n2() const { return n2_; } int64 n3() const { return n3_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return n1_ * n2_ * n3_; } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with sequentially increasing values. 
void FillIota(const T& value) { - std::iota(values_.begin(), values_.end(), value); + std::iota(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with random normal values with a mean of 0 and standard @@ -110,8 +126,8 @@ class Array3D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -119,7 +135,7 @@ class Array3D { int64 n1_; int64 n2_; int64 n3_; - std::vector values_; + std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc index fa4435dfc48..6b5f4b343b2 100644 --- a/tensorflow/compiler/xla/array3d_test.cc +++ b/tensorflow/compiler/xla/array3d_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index db51a57cf26..d93f968f4d7 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -54,13 +55,17 @@ namespace xla { template class Array4D { public: - // Creates a 4D array, unitialized values. + // Creates a 4D array, uninitialized values. Array4D(int64 planes, int64 depth, int64 height, int64 width) - : planes_(planes), depth_(depth), height_(height), width_(width) { - values_.resize(planes * depth * height * width); + : planes_(planes), + depth_(depth), + height_(height), + width_(width), + values_(new T[planes * depth * height * width]) { + Fill(T()); } - // Creates a 4D array, initalized to value. + // Creates a 4D array, initialized to value. Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) : Array4D(planes, depth, height, width) { Fill(value); @@ -107,6 +112,23 @@ class Array4D { } } + Array4D(const Array4D& other) + : Array4D(other.planes(), other.depth(), other.height(), other.width()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array4D& operator=(const Array4D& other) { + planes_ = other.planes(); + depth_ = other.depth(); + height_ = other.height(); + width_ = other.width(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + T& operator()(int64 plane, int64 depth, int64 height, int64 width) { CHECK_LT(plane, planes_); CHECK_LT(depth, depth_); @@ -131,24 +153,24 @@ class Array4D { int64 n3() const { return height_; } int64 n2() const { return depth_; } int64 n1() const { return planes_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return width_ * height_ * depth_ * planes_; } // Sets all the values in the array to values. template > void SetValues(const Container& container) { CHECK_EQ(std::distance(std::begin(container), std::end(container)), num_elements()); - values_.assign(std::begin(container), std::end(container)); + std::copy(std::begin(container), std::end(container), &values_[0]); } // Fills the array with the given value. 
void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with iota. void FillIota(const T& value) { - std::iota(values_.begin(), values_.end(), value); + std::iota(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with random variable with a deviation of value and a mean @@ -158,8 +180,8 @@ class Array4D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -264,7 +286,7 @@ class Array4D { int64 depth_; int64 height_; int64 width_; - std::vector values_; + std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 72ada467e51..3bc8148c911 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 3e9dfe2a922..63c6d9ddaca 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -46,6 +46,7 @@ cc_library( cc_test( name = "padding_test", + size = "small", srcs = ["padding_test.cc"], deps = [ ":padding", @@ -99,6 +100,26 @@ cc_library( ], ) +cc_library( + name = "compile_only_client", + srcs = ["compile_only_client.cc"], + hdrs = ["compile_only_client.h"], + deps = [ + ":client", + ":computation", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:compile_only_service", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:support", + ], +) + # This target is used to instantiate the XLA service in-process and create # a client for it. 
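Returning briefly to the Array2D/Array3D/Array4D changes above: switching the storage from std::vector<T> to std::unique_ptr<T[]> makes the newly added copy constructors and copy assignments load-bearing. A hypothetical test (not part of this diff) that checks the deep-copy behavior, in the style of array2d_test.cc:

```c++
TEST(Array2dTest, CopyIsDeep) {
  xla::Array2D<float> a(2, 3, 1.0f);
  xla::Array2D<float> b = a;  // exercises the new copy constructor
  b(0, 0) = 42.0f;
  EXPECT_EQ(a(0, 0), 1.0f);   // the copy owns separate storage
  EXPECT_EQ(b(0, 0), 42.0f);
  EXPECT_EQ(b.num_elements(), 6);
}
```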
cc_library( @@ -106,12 +127,14 @@ cc_library( srcs = ["client_library.cc"], hdrs = ["client_library.h"], deps = [ + ":compile_only_client", ":local_client", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 341c02f1a1f..454d0fbd965 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -58,34 +58,13 @@ StatusOr> Client::Transfer( "server provided response without a literal in " "TransferToClient request"); } - - return WrapUnique(response.release_literal()); -} - -Status Client::TransferInProcess(const GlobalData& data, void* destination) { - TransferToClientInProcessRequest request; - *request.mutable_data() = data.handle(); - request.set_buffer(reinterpret_cast(destination)); - TransferToClientInProcessResponse response; - - VLOG(1) << "making transfer in-process request"; - VLOG(3) << "TransferToClientInProcessRequest: {" << request.DebugString() - << "}"; - Status s = stub_->TransferToClientInProcess(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToClientInProcessResponse: {" << response.DebugString() - << "}"; - return Status::OK(); + return MakeUnique(response.literal()); } StatusOr> Client::TransferToServer( const Literal& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -113,7 +92,7 @@ StatusOr> Client::TransferToServer( Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; - *request.mutable_literal() = literal; + *request.mutable_literal() = literal.ToProto(); if (device_handle) { *request.mutable_device_handle() = *device_handle; } @@ -132,6 +111,39 @@ Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, return Status::OK(); } +StatusOr> Client::TransferFromOutfeed( + const Shape* shape_with_layout, int64 replica_id, + const DeviceHandle* device_handle) { + TransferFromOutfeedRequest request; + if (device_handle) { + *request.mutable_device_handle() = *device_handle; + } + request.set_replica_id(replica_id); + if (shape_with_layout != nullptr) { + *request.mutable_shape_with_layout() = *shape_with_layout; + } + TransferFromOutfeedResponse response; + + VLOG(1) << "making transfer from outfeed request"; + VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferFromOutfeed(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return FailedPrecondition( + "server provided response without a literal in " + "TransferToClient request"); + } + + Literal literal(response.literal()); + return MakeUnique(literal); +} + Status Client::ResetDevice() { ResetDeviceRequest request; ResetDeviceResponse response; @@ 
-164,34 +176,6 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::TransferToServerInProcess( - const Shape& shape, const void* buffer) { - TransferToServerInProcessRequest request; - request.set_buffer(reinterpret_cast(buffer)); - *request.mutable_shape() = shape; - TransferToServerInProcessResponse response; - - VLOG(1) << "making transfer to server in-process request"; - VLOG(3) << "TransferToServerInProcessRequest: {" << request.DebugString() - << "}"; - Status s = stub_->TransferToServerInProcess(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - VLOG(3) << "TransferToServerInProcessResponse: {" << response.DebugString() - << "}"; - - if (!response.has_data()) { - return FailedPrecondition( - "server provided response without a data handle in " - "TransferToServerInProcess request"); - } - - return MakeUnique(stub_, response.data()); -} - StatusOr Client::LoadSnapshot(const SessionModule& module) { LoadComputationSnapshotRequest request; *request.mutable_module() = module; @@ -269,7 +253,7 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (int64 i = 0; i < computations.size(); ++i) { + for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( MakeUnique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index f261de9d0d1..797835160fa 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -119,6 +120,15 @@ class Client { Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); + // Transfers from the Outfeed of the device. + // + // device_handle and replica_id together specify a particular device; a device + // assigned for the given replica_id among the replicas that the given device + // handle belongs to. + StatusOr> TransferFromOutfeed( + const Shape* shape_with_layout, int64 replica_id = 0, + const DeviceHandle* device_handle = nullptr); + // Resets the device, clearing all existing state on the device. Status ResetDevice(); @@ -143,8 +153,7 @@ class Client { const Computation& computation) const; // Returns the Shape of the given array specified by 'data'. The shape - // includes the Layout of the array as it is stored on the service. The layout - // information is useful for calling TransferInProcess. + // includes the Layout of the array as it is stored on the service. StatusOr GetShape(const GlobalData& data); // As above, but returns the shape of the provided computation (parameter @@ -156,24 +165,6 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - // If the service is running in the same process as the client then the - // following "InProcess" transfer methods may be used. These methods enable - // more efficient transfer of arrays to and from the service. - - // Transfer array from the service into the given buffer. The buffer must be - // large enough to hold the array. 
The array is copied verbatim (memcpy) from - // the service. The method GetShape should be called ahead of time - // to get the shape and layout of the array as it is stored in the - // service. The shape and layout can be used to determine how large the buffer - // needs to be. - Status TransferInProcess(const GlobalData& data, void* destination); - - // Transfer array to the service from the given buffer with the given shape - // and layout. The service creates an internal copy of the data so the client - // can free the buffer when this method returns. - StatusOr> TransferToServerInProcess( - const Shape& shape, const void* buffer); - StatusOr LoadSnapshot(const SessionModule& module); ServiceInterface* stub() { return stub_; } diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 93437023bc8..8238261e1c9 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -43,6 +43,16 @@ int LocalClientOptions::number_of_replicas() const { return number_of_replicas_; } +LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int LocalClientOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -69,22 +79,24 @@ ClientLibrary::~ClientLibrary() = default; TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - auto it = client_library.instances_.find(platform->id()); - if (it != client_library.instances_.end()) { + auto it = client_library.local_instances_.find(platform->id()); + if (it != client_library.local_instances_.end()) { return it->second->client.get(); } ServiceOptions service_options; service_options.set_platform(platform); service_options.set_number_of_replicas(replica_count); + service_options.set_intra_op_parallelism_threads( + options.intra_op_parallelism_threads()); - std::unique_ptr instance = MakeUnique(); + auto instance = MakeUnique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); instance->client = MakeUnique(instance->service.get()); LocalClient* cl = instance->client.get(); - client_library.instances_.insert( + client_library.local_instances_.insert( std::make_pair(platform->id(), std::move(instance))); return cl; } @@ -99,9 +111,35 @@ ClientLibrary::~ClientLibrary() = default; perftools::gputools::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); - auto it = client_library.instances_.find(platform->id()); - CHECK(it != client_library.instances_.end()); + auto it = client_library.local_instances_.find(platform->id()); + CHECK(it != client_library.local_instances_.end()); return it->second->service.get(); } +/* static */ StatusOr +ClientLibrary::GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform) { + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + auto it = client_library.compile_only_instances_.find(platform->id()); + if (it != client_library.compile_only_instances_.end()) { + return it->second->client.get(); + } + + auto instance = MakeUnique(); + 
TF_ASSIGN_OR_RETURN(instance->service, + CompileOnlyService::NewService(platform)); + instance->client = MakeUnique(instance->service.get()); + CompileOnlyClient* cl = instance->client.get(); + + client_library.compile_only_instances_.insert( + std::make_pair(platform->id(), std::move(instance))); + return cl; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 2bc319f9333..3ddd235d0ef 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -26,7 +26,9 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -51,9 +53,14 @@ class LocalClientOptions { LocalClientOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; class ClientLibrary { @@ -76,6 +83,13 @@ class ClientLibrary { // access user computations from client. static LocalService* GetXlaService(perftools::gputools::Platform* platform); + // Singleton constructor-or-accessor for compile-only clients. Arguments: + // + // platform : The platform the underlying XLA service should target. If + // null then default platform is used. + static StatusOr GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform = nullptr); + private: // Returns the singleton instance of ClientLibrary. static ClientLibrary& Singleton(); @@ -90,10 +104,21 @@ class ClientLibrary { std::unique_ptr client; }; + struct CompileOnlyInstance { + // Service that is wrapped by the singleton client object. + std::unique_ptr service; + // Singleton client object. + std::unique_ptr client; + }; + tensorflow::mutex service_mutex_; // Guards the singleton creation state. std::unordered_map> - instances_ GUARDED_BY(service_mutex_); + local_instances_ GUARDED_BY(service_mutex_); + + std::unordered_map> + compile_only_instances_ GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); }; diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc new file mode 100644 index 00000000000..d9972ef77b9 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/compile_only_client.h" + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +StatusOr>> +CompileOnlyClient::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector service_instances; + service_instances.reserve(computations.size()); + for (const AotComputationInstance& instance : computations) { + service_instances.push_back({}); + CompileOnlyService::AotComputationInstance& service_instance = + service_instances.back(); + TF_RET_CHECK(instance.computation != nullptr); + service_instance.computation = instance.computation->handle(); + service_instance.argument_layouts = instance.argument_layouts; + service_instance.result_layout = instance.result_layout; + } + return compiler_service_->CompileAheadOfTime(service_instances, options); +} + +int64 CompileOnlyClient::PointerSizeForTriple( + tensorflow::StringPiece target_triple) { + llvm::Triple triple( + llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); + if (triple.isArch64Bit()) { + return 8; + } else if (triple.isArch32Bit()) { + return 4; + } else { + CHECK(triple.isArch16Bit()); + return 2; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h new file mode 100644 index 00000000000..59000487113 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// An XLA Client specialization for doing ahead-of-time compilation. This does +// not require (or attempt to instantiate) an execution-capable backend for the +// relevant platform. +class CompileOnlyClient : public Client { + public: + explicit CompileOnlyClient(CompileOnlyService* service) + : Client(service), compiler_service_(service) {} + + CompileOnlyClient(const CompileOnlyClient&) = delete; + void operator=(const CompileOnlyClient&) = delete; + + // A description of a computation to compile using CompileAheadOfTime. 
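For illustration, a minimal ahead-of-time compilation sketch against the CompileOnlyClient API added above. `CpuAotCompilationOptions`, its constructor arguments, and the example triple are assumptions drawn from the CPU backend rather than from this diff, and error handling is elided:

```c++
// Sketch only: compile a computation for a target triple without an
// execution-capable backend. CpuAotCompilationOptions is an assumption
// from the CPU backend; the triple and entry point name are placeholders.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"

xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> CompileForTriple(
    const xla::Computation& computation, const xla::Shape& argument_shape,
    const xla::Shape& result_shape) {
  // Compile-only clients never instantiate an execution backend, so this
  // can run on a host that cannot execute code for the target platform.
  TF_ASSIGN_OR_RETURN(xla::CompileOnlyClient * client,
                      xla::ClientLibrary::GetOrCreateCompileOnlyClient());

  xla::CompileOnlyClient::AotComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = {&argument_shape};  // Expected argument layouts.
  instance.result_layout = &result_shape;         // Expected result layout.

  // The triple determines the emitted code, including pointer width:
  // PointerSizeForTriple("x86_64-pc-linux") == 8.
  xla::cpu::CpuAotCompilationOptions options(
      /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
      /*entry_point_name=*/"entry",
      xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);

  TF_ASSIGN_OR_RETURN(auto results,
                      client->CompileAheadOfTime({instance}, options));
  return std::move(results.front());
}
```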
+ struct AotComputationInstance { + const Computation* computation; + // Inform the compiler of the expected layout for arguments. + std::vector argument_layouts; + // Specifies the expected result layout. + const Shape* result_layout; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. The |options| parameter describes + // the target for which the compiler should emit code. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options); + + // Returns the size of a pointer in bytes for a given triple. + static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + + private: + CompileOnlyService* compiler_service_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc index cd7d8df58b8..4baea8df6e3 100644 --- a/tensorflow/compiler/xla/client/computation.cc +++ b/tensorflow/compiler/xla/client/computation.cc @@ -28,12 +28,12 @@ Computation::Computation(ServiceInterface* parent, : handle_(handle), parent_(parent) {} Computation::Computation(Computation&& computation) - : handle_(computation.handle_), parent_(computation.parent_) { + : handle_(std::move(computation.handle_)), parent_(computation.parent_) { computation.ResetWithoutFreeing(); } void Computation::Reset() { - // TODO(leary) deallocate any owned computation. + // TODO(b/34469253) deallocate any owned computation. ResetWithoutFreeing(); } diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 73f450e1f2e..37bf697683b 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -106,7 +106,7 @@ bool ComputationBuilder::MakeWindow( tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const int64 x, const char* x_name) { + const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return true; } else { @@ -165,12 +165,14 @@ ComputationDataHandle ComputationBuilder::ConstantOp( } ConstantRequest request; - Literal* literal = request.mutable_literal(); - populate(literal); - VLOG(3) << "created constant: " << literal->ShortDebugString(); + Literal literal; + populate(&literal); + *request.mutable_literal() = literal.ToProto(); + VLOG(3) << "created constant: " << request.literal().ShortDebugString(); OpRequest op_request; *op_request.mutable_constant_request() = request; *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making constant request"; @@ -198,6 +200,7 @@ ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_parameter_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making parameter request"; @@ -253,7 +256,8 @@ void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, ComputationDataHandle ComputationBuilder::Slice( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { + tensorflow::gtl::ArraySlice limit_indices, + 
tensorflow::gtl::ArraySlice stride) { if (!first_error_.ok() || !PrepareComputation().ok()) { return ComputationDataHandle(); } @@ -266,9 +270,13 @@ ComputationDataHandle ComputationBuilder::Slice( for (int64 index : limit_indices) { request.add_limit_indices(index); } + for (int64 index : stride) { + request.add_stride(index); + } OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_slice_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making slice request"; @@ -293,6 +301,7 @@ ComputationDataHandle ComputationBuilder::DynamicSlice( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_dynamic_slice_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making dynamic slice request"; @@ -314,6 +323,7 @@ ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_dynamic_update_slice_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making dynamic update slice request"; @@ -336,6 +346,7 @@ ComputationDataHandle ComputationBuilder::ConcatInDim( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_concatenate_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making concatenate request"; @@ -358,6 +369,7 @@ ComputationDataHandle ComputationBuilder::Broadcast( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_broadcast_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making broadcast request"; @@ -380,6 +392,7 @@ ComputationDataHandle ComputationBuilder::Pad( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_pad_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making pad request"; @@ -406,6 +419,7 @@ ComputationDataHandle ComputationBuilder::Reshape( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_reshape_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making reshape request"; @@ -440,7 +454,8 @@ ComputationDataHandle ComputationBuilder::Collapse( // Don't support out-of-order collapse here. // Checks that the collapsed dimensions are in order and consecutive. 
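A usage sketch for the stride values now threaded through SliceRequest above, assuming the documented [start, limit) semantics with a per-dimension step:

```c++
// Strided slice of a rank-1 constant; a sketch of the extended Slice API.
#include "tensorflow/compiler/xla/client/computation_builder.h"

void BuildStridedSlice(xla::ComputationBuilder* b) {
  auto x = b->ConstantR1<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Selects elements 0, 2, 4 and 6 of the single dimension:
  // start = 0, limit = 8 (exclusive), stride = 2.
  auto strided = b->Slice(x, /*start_indices=*/{0}, /*limit_indices=*/{8},
                          /*stride=*/{2});
  (void)strided;
}
```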
- for (int i = 1; i < dims_to_collapse.size(); ++i) { + for (tensorflow::gtl::ArraySlice::size_type i = 1; + i < dims_to_collapse.size(); ++i) { if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { NoteError(InvalidArgument( "Collapsed dimensions are not in order and consecutive.")); @@ -482,6 +497,7 @@ void ComputationBuilder::Trace(const string& tag, OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_trace_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making trace request"; @@ -513,6 +529,7 @@ ComputationDataHandle ComputationBuilder::Tuple( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_variadic_op_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making variadic op request"; @@ -532,6 +549,7 @@ ComputationDataHandle ComputationBuilder::GetTupleElement( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_get_tuple_element_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making get tuple element op request"; @@ -681,14 +699,15 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( std::vector base_area_dimensions( dimension_numbers.spatial_dimensions_size()); - for (int i = 0; i < base_area_dimensions.size(); ++i) { + for (std::vector::size_type i = 0; i < base_area_dimensions.size(); + ++i) { base_area_dimensions[i] = lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); } std::vector window_dimensions( dimension_numbers.kernel_spatial_dimensions_size()); - for (int i = 0; i < window_dimensions.size(); ++i) { + for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } @@ -740,7 +759,7 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( std::vector window_dimensions( dimension_numbers.kernel_spatial_dimensions_size()); - for (int i = 0; i < window_dimensions.size(); ++i) { + for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } @@ -758,6 +777,7 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_convolve_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making convolve request"; @@ -777,6 +797,7 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_infeed_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making infeed op request"; @@ -786,6 +807,7 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, + const Shape& shape, const string& outfeed_config) { if (!first_error_.ok() || !PrepareComputation().ok()) { return; @@ -794,9 +816,11 @@ void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, OutfeedRequest request; request.set_outfeed_config(outfeed_config); *request.mutable_operand() = operand; + *request.mutable_shape() = shape; OpRequest op_request; *op_request.mutable_outfeed_request() = request; *op_request.mutable_computation() = computation_.handle(); + 
AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making outfeed op request"; @@ -823,6 +847,7 @@ ComputationDataHandle ComputationBuilder::Call( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_call_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making call op request"; @@ -848,6 +873,7 @@ ComputationDataHandle ComputationBuilder::CustomCall( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_custom_call_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making custom call op request"; @@ -950,22 +976,31 @@ ComputationDataHandle ComputationBuilder::Tanh( return UnaryOp(UNOP_TANH, operand); } +ComputationDataHandle ComputationBuilder::IsFinite( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_IS_FINITE, operand); +} + ComputationDataHandle ComputationBuilder::Transpose( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice permutation) { - if (!first_error_.ok()) { + if (!first_error_.ok() || !PrepareComputation().ok()) { return ComputationDataHandle(); } - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - // Just early return with the existing error status. - first_error_ = shape.status(); - return ComputationDataHandle(); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + TransposeRequest* request = op_request.mutable_transpose_request(); + *request->mutable_operand() = operand; + for (int64 dimension : permutation) { + request->add_dimensions(dimension); } - return Reshape(operand, permutation, - Permute(InversePermutation(permutation), - AsInt64Slice(shape.ValueOrDie()->dimensions()))); + AddOpMetadata(&op_request); + OpResponse response; + + VLOG(2) << "making transpose request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::Rev( @@ -983,6 +1018,7 @@ ComputationDataHandle ComputationBuilder::Rev( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_reverse_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making reverse op request"; @@ -1003,8 +1039,9 @@ ComputationDataHandle ComputationBuilder::SqrtF32( } ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_POW, lhs, rhs, /*broadcast_dimensions=*/{}); + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); } ComputationDataHandle ComputationBuilder::ConvertElementType( @@ -1027,6 +1064,7 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_convert_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making convert request"; @@ -1070,6 +1108,7 @@ ComputationDataHandle ComputationBuilder::UnaryOp( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_unary_op_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making unop request"; @@ -1096,6 +1135,7 @@ ComputationDataHandle ComputationBuilder::BinaryOp( OpRequest op_request; *op_request.mutable_computation() = 
computation_.handle(); *op_request.mutable_binary_op_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making binop request"; @@ -1121,6 +1161,7 @@ ComputationDataHandle ComputationBuilder::RngOp( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_rng_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making rngop request"; @@ -1144,6 +1185,7 @@ ComputationDataHandle ComputationBuilder::TernaryOp( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_ternary_op_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making triop request"; @@ -1192,8 +1234,7 @@ StatusOr ComputationBuilder::IsConstant( VLOG(2) << "done with request"; if (!s.ok()) { - NoteError(s); - return first_error_; + return s; } return response.is_constant(); } @@ -1218,8 +1259,7 @@ StatusOr> ComputationBuilder::ComputeConstant( VLOG(2) << "done with request"; if (!s.ok()) { - NoteError(s); - return first_error_; + return s; } TF_RET_CHECK(response.output().handle() != 0); @@ -1245,6 +1285,7 @@ ComputationDataHandle ComputationBuilder::Map( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_map_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making Map request"; @@ -1283,6 +1324,7 @@ ComputationDataHandle ComputationBuilder::While( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_while_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making while request"; @@ -1308,6 +1350,7 @@ ComputationDataHandle ComputationBuilder::Reduce( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_reduce_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making reduce request"; @@ -1360,6 +1403,7 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_reduce_window_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making reduce-window request"; @@ -1378,6 +1422,7 @@ ComputationDataHandle ComputationBuilder::CrossReplicaSum( OpRequest op_request; *op_request.mutable_cross_replica_sum_request() = request; *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making cross-replica-sum request"; @@ -1434,6 +1479,7 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_select_and_scatter_request() = request; + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making select-and-scatter request"; @@ -1453,11 +1499,12 @@ void ComputationBuilder::Send(const ComputationDataHandle& operand, OpRequest op_request; *op_request.mutable_send_request() = request; *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making send request"; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with request"; + Status s = client_->stub()->Op(&op_request, &response); + VLOG(2) << "done with op request"; if (!s.ok()) { NoteError(s); @@ -1477,12 +1524,11 @@ 
ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, OpRequest op_request; *op_request.mutable_recv_request() = request; *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); OpResponse response; VLOG(2) << "making recv request"; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with request"; - + Status s = client_->stub()->Op(&op_request, &response); return ParseOpResponse(s, &response); } @@ -1512,6 +1558,11 @@ StatusOr<Computation> ComputationBuilder::Build() { return {std::move(computation_)}; } +void ComputationBuilder::AddOpMetadata(OpRequest* request) const { + tensorflow::mutex_lock lock(mutex_); + *request->mutable_metadata() = metadata_; +} + /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 67ca9c6cf74..5cc73c28d03 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -61,6 +62,23 @@ class ComputationBuilder { // Returns the computation name. const string& name() { return name_; } + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { + tensorflow::mutex_lock lock(mutex_); + metadata_ = metadata; + } + + // Clears the OpMetadata state. + void ClearOpMetadata() { + tensorflow::mutex_lock lock(mutex_); + metadata_.Clear(); + } + // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is // called (which is the default). @@ -193,9 +211,11 @@ class ComputationBuilder { // // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. + // The stride parameter determines the stride over the slice. ComputationDataHandle Slice(const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> start_indices, - tensorflow::gtl::ArraySlice<int64> limit_indices); + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> stride); // Enqueues a slice operation onto the computation that slices the 'operand' // from dynamic start indices which are passed in 'start_indices'. @@ -352,13 +372,13 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - // Enqueues an infeed instruction onto the computation, which reads data of - // the given shape from the infeed buffer of the device. + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device.
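A sketch of the "modal" OpMetadata flow introduced above: metadata set on the builder is attached to every instruction enqueued until it is cleared. The `op_type`/`op_name` fields are assumptions based on the OpMetadata proto:

```c++
// Attach metadata to a run of instructions, then clear it; a usage sketch.
#include "tensorflow/compiler/xla/client/computation_builder.h"

void BuildWithMetadata(xla::ComputationBuilder* b,
                       const xla::ComputationDataHandle& x,
                       const xla::ComputationDataHandle& y) {
  xla::OpMetadata metadata;
  metadata.set_op_type("Add");         // e.g. the originating TF op type
  metadata.set_op_name("layer1/add");  // e.g. the originating TF node name
  b->SetOpMetadata(metadata);

  // Both instructions below carry the metadata set above.
  auto sum = b->Add(x, y);
  auto doubled = b->Add(sum, sum);

  // Instructions enqueued from here on carry no metadata.
  b->ClearOpMetadata();
  (void)doubled;
}
```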
ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. - void Outfeed(const ComputationDataHandle& operand, + void Outfeed(const ComputationDataHandle& operand, const Shape& shape, const string& outfeed_config); // Enqueues a call instruction onto the computation. @@ -504,8 +524,15 @@ class ComputationBuilder { ComputationDataHandle SquareF32(const ComputationDataHandle& operand); // Enqueues a lhs^rhs computation onto the computation. - ComputationDataHandle Pow(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + ComputationDataHandle Pow( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry is finite. + ComputationDataHandle IsFinite(const ComputationDataHandle& operand); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. @@ -516,8 +543,8 @@ class ComputationBuilder { // (float32 is specified as there is an implicit float32 -1.0f constant // exponent). // - // TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of - // the operand. + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand); // Enqueues a negate instruction onto the computation. @@ -586,6 +613,48 @@ class ComputationBuilder { // computation. StatusOr<bool> IsConstant(const ComputationDataHandle& operand); + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, + const ComputationDataHandle& scale, + const ComputationDataHandle& offset, + float epsilon, int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // instead uses the input `mean` and `variance` as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + ComputationDataHandle BatchNormInference( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, const ComputationDataHandle& mean, + const ComputationDataHandle& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch.
+ // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand, + const ComputationDataHandle& scale, + const ComputationDataHandle& batch_mean, + const ComputationDataHandle& batch_var, + const ComputationDataHandle& grad_output, + float epsilon, int64 feature_index); + // Computes the value of a constant indicated by a // ComputationDataHandle. // @@ -643,6 +712,14 @@ class ComputationBuilder { // then Build() should be used instead. Computation BuildAndNoteError(); + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // ComputationDataHandle and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + private: using PopulateLiteral = std::function; @@ -710,6 +787,8 @@ class ComputationBuilder { // * dying if die_immediately_on_error_ is true void NoteError(const Status& error); + void AddOpMetadata(OpRequest* request) const; + string name_; // Name to use for the built computation. // The first error encountered while building the computation. @@ -728,6 +807,14 @@ class ComputationBuilder { // Mode bit that indicates whether to die when a first error is encountered. bool die_immediately_on_error_{false}; + // Mutex to guard against concurrent access to metadata_. + mutable tensorflow::mutex mutex_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_ GUARDED_BY(mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); }; @@ -804,7 +891,7 @@ template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR4FromArray4D(values, layout, literal); + LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal); }); } diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index be706f7d232..40f59eaa68e 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -23,7 +24,7 @@ limitations under the License. namespace xla { GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) - : handle_(handle), parent_(parent) {} + : handle_(std::move(handle)), parent_(parent) {} GlobalData::~GlobalData() { UnregisterRequest request; diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h index eb11d91034b..b7929357d06 100644 --- a/tensorflow/compiler/xla/client/global_data.h +++ b/tensorflow/compiler/xla/client/global_data.h @@ -23,13 +23,15 @@ limitations under the License. namespace xla { -// Wraps a GlobalDataHandle with a lifetime. 
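An illustrative sequence over the three batch-norm entry points declared above. The [batch, feature] operand layout, the `feature_index` of 1, and the epsilon value are assumptions for the sketch:

```c++
// Build training-mode batch norm, then reuse its statistics for inference.
#include "tensorflow/compiler/xla/client/computation_builder.h"

void BuildBatchNorm(xla::ComputationBuilder* b,
                    const xla::ComputationDataHandle& operand,
                    const xla::ComputationDataHandle& scale,
                    const xla::ComputationDataHandle& offset) {
  const float epsilon = 1e-5f;
  const xla::int64 feature_index = 1;

  // Training: yields a (normalized, batch_mean, batch_var) tuple.
  auto training =
      b->BatchNormTraining(operand, scale, offset, epsilon, feature_index);
  auto normalized = b->GetTupleElement(training, 0);
  auto batch_mean = b->GetTupleElement(training, 1);
  auto batch_var = b->GetTupleElement(training, 2);

  // Inference: reuses the estimated mean/variance rather than recomputing
  // them per batch, which is where the latency saving comes from.
  auto inference = b->BatchNormInference(operand, scale, offset, batch_mean,
                                         batch_var, epsilon, feature_index);
  (void)normalized;
  (void)inference;
}
```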
+// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. class GlobalData { public: // Gives ownership of the global data handle to this object. GlobalData(ServiceInterface* parent, GlobalDataHandle handle); - // Unregisters the wrapped handle. + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. ~GlobalData(); const GlobalDataHandle& handle() const { return handle_; } diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index e185beaedd3..86b16be62f0 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -19,6 +19,7 @@ cc_library( hdrs = ["arithmetic.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 31efd8ee64d..a45974b86b6 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -64,4 +65,33 @@ Computation CreateScalarMinComputation(PrimitiveType type, return b->BuildAndNoteError(); } +Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(PRED, {}); + auto b = builder->CreateSubBuilder("logical_and"); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->LogicalAnd(lhs, rhs); + return b->BuildAndNoteError(); +} + +Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder) { + const Shape scalar = ShapeUtil::MakeShape(PRED, {}); + auto b = builder->CreateSubBuilder("logical_or"); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + b->LogicalOr(lhs, rhs); + return b->BuildAndNoteError(); +} + +StatusOr Any(const ComputationDataHandle& predicates, + ComputationBuilder* builder) { + auto f = builder->ConstantR0(false); + Computation logical_or = CreateScalarLogicalOrComputation(builder); + TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return builder->Reduce(predicates, f, logical_or, all_dimensions); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 57fe7d74624..633086a2e7e 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -40,6 +40,18 @@ Computation CreateScalarMaxComputation(PrimitiveType type, Computation CreateScalarMinComputation(PrimitiveType type, ComputationBuilder* builder); +// Creates a scalar logical AND computation and returns it. +Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder); + +// Creates a scalar logical OR computation and returns it. 
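How the new helpers compose: IsFinite produces a PRED array, and Any() reduces it with the scalar logical-OR computation over every dimension. A sketch; the function name is illustrative and errors propagate via TF_ASSIGN_OR_RETURN as elsewhere in this diff:

```c++
// Reduce a predicate array to a scalar with the new Any() helper.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/status_macros.h"

tensorflow::Status BuildAnyFinite(xla::ComputationBuilder* b,
                                  const xla::ComputationDataHandle& values) {
  // PRED array of the same shape as `values`.
  auto finite = b->IsFinite(values);
  // Scalar PRED result: true iff at least one entry is finite. A zero-sized
  // predicate array vacuously reduces to false.
  TF_ASSIGN_OR_RETURN(xla::ComputationDataHandle any_finite,
                      xla::Any(finite, b));
  (void)any_finite;
  return tensorflow::Status::OK();
}
```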
+Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder); + +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +StatusOr Any(const ComputationDataHandle& predicates, + ComputationBuilder* builder); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 384aae867b1..96944a53b7e 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/status_macros.h" namespace se = ::perftools::gputools; @@ -67,35 +68,13 @@ bool ExecutableBuildOptions::has_hybrid_result() const { } namespace { - -// Convenience class which holds an acquired stream from the backend and -// automatically releases it when destructed. -class StreamManager { - public: - static StatusOr> AcquireStream( - Backend* backend, int device_ordinal) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend->stream_executor(device_ordinal == -1 - ? backend->default_device_ordinal() - : device_ordinal)); - TF_ASSIGN_OR_RETURN(std::unique_ptr stream, - backend->AcquireStream(executor)); - return WrapUnique(new StreamManager(backend, std::move(stream))); +StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { + if (device_ordinal < 0) { + device_ordinal = backend->default_device_ordinal(); } - - ~StreamManager() { backend_->ReleaseStream(std::move(stream_)); } - - se::Stream* stream() const { return stream_.get(); } - - private: - StreamManager(Backend* backend, std::unique_ptr stream) - : backend_(backend), stream_(std::move(stream)) {} - - Backend* backend_; - std::unique_ptr stream_; -}; - + return backend->BorrowStream(device_ordinal); +} } // namespace LocalExecutable::LocalExecutable(std::unique_ptr executable, @@ -108,7 +87,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, tensorflow::Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options) { + const ExecutableRunOptions& options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -177,71 +156,54 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( run_executor->GetDeviceDescription().name().c_str()); } + if (!options.allocator()) { + return InvalidArgument("an allocator must be provided to ExecuteLocally"); + } + + if (options.allocator()->platform() != backend.platform()) { + return InvalidArgument( + "allocator platform (%s) does not match service platform (%s)", + options.allocator()->platform()->Name().c_str(), + backend.platform()->Name().c_str()); + } + return tensorflow::Status::OK(); } StatusOr> LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& options) { - TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options)); + TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); ExecutableRunOptions actual_options = options; - std::unique_ptr acquired_stream; if (options.stream() == nullptr) { 
TF_ASSIGN_OR_RETURN( - acquired_stream, - StreamManager::AcquireStream(backend_, options.device_ordinal())); - actual_options.set_stream(acquired_stream->stream()); + Backend::StreamPtr stream, + BorrowStreamForDevice(options.device_ordinal(), backend_)); + actual_options.set_stream(stream.get()); } if (options.allocator() == nullptr) { actual_options.set_allocator(backend_->memory_allocator()); } - if (executable_->dumping()) { - return ExecuteAndDump(&actual_options, arguments); - } - return executable_->ExecuteOnStream(&actual_options, arguments, - /*hlo_execution_profile=*/nullptr); -} - -tensorflow::Status LocalExecutable::Run( - const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, ShapedBuffer* result) { - const ComputationLayout& computation_layout = - executable_->module_config().entry_computation_layout(); - TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options)); - - if (!computation_layout.result_layout().MatchesLayoutInShape( - result->shape())) { - return InvalidArgument( - "result buffer does not match shape or layout of computation result: " - "expected %s, got %s", - ShapeUtil::HumanString(computation_layout.result_layout().shape()) - .c_str(), - ShapeUtil::HumanString(result->shape()).c_str()); - } - - ExecutableRunOptions actual_options = options; - std::unique_ptr acquired_stream; - if (options.stream() == nullptr) { - TF_ASSIGN_OR_RETURN( - acquired_stream, - StreamManager::AcquireStream(backend_, options.device_ordinal())); - actual_options.set_stream(acquired_stream->stream()); - } - if (options.allocator() == nullptr) { - actual_options.set_allocator(backend_->memory_allocator()); - } + // For local client execution on CPU backends: + // *) The thread pool used for eigen CPU ops is from + // ExecutableRunOptions.eigen_intra_op_thread_pool. + // *) The thread pool used for XLA CPU ops is from + // backend_->eigen_intra_op_thread_pool(). 
+ ServiceExecutableRunOptions service_options( + actual_options, backend_->StreamBorrower(), + backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { - return Unimplemented("dumping execution not supported on this path"); + return ExecuteAndDump(&service_options, arguments); } - return executable_->ExecuteOnStream(&actual_options, arguments, result, - /*hlo_execution_profile=*/nullptr); + return executable_->ExecuteOnStreamWrapper<StatusOr<std::unique_ptr<ShapedBuffer>>>( + &service_options, options.execution_profile(), arguments); } StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::ExecuteAndDump( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) { executable_->session_module()->set_execution_platform( backend_->platform()->Name()); @@ -260,8 +222,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_RETURN_IF_ERROR( - LiteralFromShapedBuffer(*argument, session_module->add_arguments())); + Literal literal; + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); + *session_module->add_arguments() = literal.ToProto(); } return tensorflow::Status::OK(); } @@ -269,9 +232,13 @@ tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* result, SessionModule* session_module) { session_module->clear_result(); - return LiteralFromShapedBuffer(*result, session_module->mutable_result()); + Literal literal(session_module->result()); + TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); + *session_module->mutable_result() = literal.ToProto(); + return tensorflow::Status::OK(); } +// TODO(dnovillo) Change signature to return StatusOr.
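An end-to-end sketch of the reworked local execution path: create a LocalClient with an explicit intra-op thread pool size, compile, then run with an ExecutionProfile attached through the new ExecutableRunOptions::set_execution_profile(). Shapes and the thread count are illustrative; note that Run() now requires an allocator and borrows a stream from the backend when none is supplied:

```c++
// Compile a Computation and run it locally with profiling enabled.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/status_macros.h"

xla::StatusOr<std::unique_ptr<xla::ShapedBuffer>> CompileAndRun(
    const xla::Computation& computation, const xla::Shape& argument_shape,
    const xla::ShapedBuffer& argument) {
  xla::LocalClientOptions client_options;
  client_options.set_intra_op_parallelism_threads(4);
  TF_ASSIGN_OR_RETURN(
      xla::LocalClient * client,
      xla::ClientLibrary::GetOrCreateLocalClient(client_options));

  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                      client->Compile(computation, {&argument_shape},
                                      xla::ExecutableBuildOptions()));

  // An allocator is mandatory and must match the service platform; the
  // stream, left unset here, is borrowed from the backend by Run().
  xla::ExecutionProfile profile;
  xla::ExecutableRunOptions run_options;
  run_options.set_allocator(client->mutable_backend()->memory_allocator());
  run_options.set_execution_profile(&profile);
  return executable->Run({&argument}, run_options);
}
```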
tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer, Literal* literal) { TF_ASSIGN_OR_RETURN( @@ -290,62 +257,6 @@ StatusOr> LocalClient::AllocateBufferOnDevice( return std::unique_ptr(new GlobalData(local_service_, handle)); } -tensorflow::Status LocalClient::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - return local_service_->ResolveArguments(arguments, device_ordinal, - argument_ptrs); -} - -StatusOr> LocalClient::ExecuteLocally( - const Computation& computation, - const tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options) { - return local_service_->ExecuteLocally(computation.handle(), arguments, - options); -} - -tensorflow::Status LocalClient::ExecuteLocally( - const Computation& computation, - const tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, ShapedBuffer* result) { - return local_service_->ExecuteLocally(computation.handle(), arguments, - options, result); -} - -StatusOr>> -LocalClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AheadOfTimeComputationInstance& instance : computations) { - service_instances.push_back({}); - LocalService::AheadOfTimeComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return local_service_->CompileAheadOfTime(service_instances, options); -} - -int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) { - llvm::Triple triple( - llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); - if (triple.isArch64Bit()) { - return 8; - } else if (triple.isArch32Bit()) { - return 4; - } else { - CHECK(triple.isArch16Bit()); - return 2; - } -} - se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } @@ -362,6 +273,14 @@ int LocalClient::default_device_ordinal() const { return local_service_->backend().default_device_ordinal(); } +const Backend& LocalClient::backend() const { + return local_service_->backend(); +} + +Backend* LocalClient::mutable_backend() { + return local_service_->mutable_backend(); +} + StatusOr> LocalClient::Compile( const Computation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 33366b97fd5..c903cd27112 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -56,7 +56,7 @@ class ExecutableBuildOptions { // If set, this specifies the layout of the result of the computation. If not // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accomodate tuple result shapes. A value of nullptr + // store the layout to accommodate tuple result shapes. A value of nullptr // indicates the option has not been set. 
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; @@ -83,12 +83,6 @@ class LocalExecutable { const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& options); - // Overload which places the computation result in the given preallocated - // buffer. - tensorflow::Status Run( - const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, ShapedBuffer* result); - // Return the layout (contained in a shape) of the result produced by the // computation. const Shape& result_layout() const { @@ -117,12 +111,12 @@ class LocalExecutable { // of the computation. tensorflow::Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options); + const ExecutableRunOptions& options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. StatusOr> ExecuteAndDump( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); // Records the arguments used to invoke the computation in a SessionModule @@ -154,7 +148,7 @@ class LocalExecutable { const ExecutableBuildOptions& build_options_; }; -// An XLA service client object for use when the client and service run in +// An XLA Client specialization for use when the client and service run in // the same process. class LocalClient : public Client { public: @@ -164,14 +158,6 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // For an array of arguments held on the local service, validate - // that each is placed on the specified device_ordinal, and return - // the DeviceMemoryBase corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal on the local service. If // allocate_space_for_deep_copy, the buffer is large enough to hold @@ -181,37 +167,6 @@ class LocalClient : public Client { const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy); - // Executes the given computation with the given arguments and - // options. Arguments and result are "zero-copy", and are passed as pointers - // to device memory. See LocalExecuteOptions class comments for description of - // available options. The returned ShapedBuffer includes pointer(s) to device - // memory (DeviceMemoryBase) which are the caller's responsibility to - // deallocate. The layout of the result is chosen by the XLA service and - // should not be relied upon to be a specific value. If a specific result - // layout is needed, then the layout should be set in options. - // - // The arrays of arguments with different shapes or layouts are assumed not to - // alias. - // - // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use - // Compile and run the returned LocalExecutable. - StatusOr> ExecuteLocally( - const Computation& computation, - const tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options); - - // Overload of ExecuteLocally which writes the result into the given - // ShapedBuffer "result". 
Result is const because the ShapedBuffer data - // structure itself is not modified, only the buffers in device memory to - // which it refers. - // - // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use - // Compile and run the returned LocalExecutable. - tensorflow::Status ExecuteLocally( - const Computation& computation, - const tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, ShapedBuffer* result); - // Build and return a LocalExecutable object. The executable is compiled using // the given argument layouts and options. StatusOr> Compile( @@ -219,30 +174,6 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AheadOfTimeComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - // - // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its - // own library. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options); - - // Returns the size of a pointer in bytes for a given triple. - static int64 PointerSizeForTriple(tensorflow::StringPiece triple); - // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; @@ -261,6 +192,10 @@ class LocalClient : public Client { // capability). bool device_ordinal_supported(int device_ordinal) const; + // Returns the backend used to execute computations. + const Backend& backend() const; + Backend* mutable_backend(); + private: LocalService* local_service_; }; diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 281fa104084..0b18d8946a2 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -35,7 +35,7 @@ std::vector> MakePadding( return low_high_padding; case Padding::kSame: - for (int64 i = 0; i < input_dimensions.size(); ++i) { + for (size_t i = 0; i < input_dimensions.size(); ++i) { int64 input_dimension = input_dimensions[i]; int64 window_dimension = window_dimensions[i]; int64 window_stride = window_strides[i]; diff --git a/tensorflow/compiler/xla/differential_set.h b/tensorflow/compiler/xla/differential_set.h deleted file mode 100644 index 9eae24ce30e..00000000000 --- a/tensorflow/compiler/xla/differential_set.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ -#define TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ - -#include - -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// In the base case, the differential set is just a set. -// However, you can also point a differential set at another differential set to -// use as a "parent". This makes a chain of sets, which each node in the chain -// adds some number of elements to the "Contains" property. -// -// E.g. if the base set holds {1, 2}, you can create a derived set that holds -// {3}, and the derived set will tell you it contains {1, 2, 3} whereas the base -// will continue to tell you it holds only {1, 2}. -template -class DifferentialSet { - public: - // Constructs a differential set capable of holding values. It may have an - // ancestor link, which would it into a chain of sets. - explicit DifferentialSet(const DifferentialSet* parent = nullptr) - : parent_(parent) {} - - // Adds a value to be held directly by this set. - void Add(T value) { held_.insert(value); } - - // Returns whether this set holds the given value, or any ancestor in the - // chain of sets. - bool Contains(T value) const { - return held_.find(value) != held_.end() || - (parent_ != nullptr && parent_->Contains(value)); - } - - private: - // Values held directly by this node in the chain of sets. - std::set held_; - - // Parent node in the chain of sets. - const DifferentialSet* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(DifferentialSet); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_ diff --git a/tensorflow/compiler/xla/differential_set_test.cc b/tensorflow/compiler/xla/differential_set_test.cc deleted file mode 100644 index dacbbcc1adb..00000000000 --- a/tensorflow/compiler/xla/differential_set_test.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/differential_set.h" - -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingHeld) { - DifferentialSet set; - set.Add(1); - set.Add(2); - EXPECT_FALSE(set.Contains(3)); - EXPECT_TRUE(set.Contains(1)); - EXPECT_TRUE(set.Contains(2)); - EXPECT_FALSE(set.Contains(0)); -} - -TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingParentHolds) { - DifferentialSet parent; - parent.Add(1); - DifferentialSet child(&parent); - child.Add(2); - - // Test properties of the child. - EXPECT_FALSE(child.Contains(3)); - EXPECT_TRUE(child.Contains(1)); - EXPECT_TRUE(child.Contains(2)); - EXPECT_FALSE(child.Contains(0)); - - // Test properties of the parent. 
- EXPECT_TRUE(parent.Contains(1)); - EXPECT_FALSE(parent.Contains(2)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 1c54fec97ce..67f3a6c1df4 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -67,4 +67,14 @@ const Eigen::ThreadPoolDevice* ExecutableRunOptions::intra_op_thread_pool() return intra_op_thread_pool_; } +ExecutableRunOptions& ExecutableRunOptions::set_execution_profile( + ExecutionProfile* profile) { + execution_profile_ = profile; + return *this; +} + +ExecutionProfile* ExecutableRunOptions::execution_profile() const { + return execution_profile_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 212fce9eab7..03f2d016ad0 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -40,6 +40,7 @@ struct ThreadPoolDevice; namespace xla { class DeviceMemoryAllocator; +class ExecutionProfile; // Class containing options for running a LocalExecutable. class ExecutableRunOptions { @@ -74,12 +75,17 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool); const Eigen::ThreadPoolDevice* intra_op_thread_pool() const; + // If set, profiling information is written to 'profile'. + ExecutionProfile* execution_profile() const; + ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; perftools::gputools::Stream* stream_ = nullptr; tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; + ExecutionProfile* execution_profile_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 901fcd89ea2..76c0168f370 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -28,13 +28,13 @@ namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { - CHECK_EQ(shape.dimensions_size(), multi_index.size()); + DCHECK_EQ(shape.dimensions_size(), multi_index.size()); // Padding and nested layouts not supported yet. - CHECK_EQ(0, shape.layout().padded_dimensions_size()); + DCHECK_EQ(0, shape.layout().padded_dimensions_size()); - for (int i = 0; i < multi_index.size(); ++i) { - CHECK_GE(multi_index[i], 0); - CHECK_LT(multi_index[i], shape.dimensions(i)) + for (size_t i = 0; i < multi_index.size(); ++i) { + DCHECK_GE(multi_index[i], 0); + DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") << "\n\tshape: " << ShapeUtil::HumanString(shape); @@ -77,9 +77,17 @@ namespace xla { // Scale factor holding the growing product of D{L(i)} terms. 
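The linearization loop above, restated as a standalone sketch: walk the dimensions from most-minor to most-major, scaling each index digit by the product of the extents of all more-minor dimensions. The first iteration skips the two multiplies, matching the change in this hunk:

```c++
// Minor-to-major linearization of a multidimensional index.
#include <cstdint>
#include <vector>

int64_t LinearIndex(const std::vector<int64_t>& extents_minor_to_major,
                    const std::vector<int64_t>& digits_minor_to_major) {
  int64_t scale = 1;
  int64_t linear = 0;
  bool first = true;
  for (size_t i = 0; i < digits_minor_to_major.size(); ++i) {
    if (first) {
      linear = digits_minor_to_major[i];  // scale is implicitly 1 here
      scale = extents_minor_to_major[i];
      first = false;
    } else {
      linear += scale * digits_minor_to_major[i];
      scale *= extents_minor_to_major[i];
    }
  }
  return linear;
}
// For a shape [3,4] with minor_to_major {1,0} and multi-index (2,3), the
// minor-to-major digits are {3,2} with extents {4,3}, so the linear index
// is 3 + 4*2 == 11.
```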
int64 scale = 1; int64 linear_index = 0; + bool first = true; for (auto dimension : shape.layout().minor_to_major()) { - linear_index += scale * multi_index[dimension]; - scale *= shape.dimensions(dimension); + if (first) { + // Avoid two multiplies on the first loop iteration + linear_index = multi_index[dimension]; + scale = shape.dimensions(dimension); + first = false; + } else { + linear_index += scale * multi_index[dimension]; + scale *= shape.dimensions(dimension); + } } return linear_index; } @@ -87,9 +95,9 @@ namespace xla { /* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( const Shape& shape, int64 linear_index) { // Padding and nested layouts not supported yet. - CHECK_EQ(0, shape.layout().padded_dimensions_size()); - CHECK_GE(linear_index, 0); - CHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); + DCHECK_EQ(0, shape.layout().padded_dimensions_size()); + DCHECK_GE(linear_index, 0); + DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); // The following formula computes each element of the multidimensional index // (See comments in MultidimensionalIndexToLinearIndex for notation): @@ -110,17 +118,36 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices(const Shape& shape, - std::vector* indices) { - for (int64 dimno = indices->size() - 1; dimno >= 0; --dimno) { +/* static */ bool IndexUtil::BumpIndices( + const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { + for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); - if ((*indices)[dimno] + 1 < limit) { - (*indices)[dimno]++; - std::fill(indices->begin() + dimno + 1, indices->end(), 0); + if (indices[dimno] + 1 < limit) { + indices[dimno]++; + std::fill(indices.begin() + dimno + 1, indices.end(), 0); return true; } } return false; } +/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, + int64 dimension) { + const Layout& layout = shape.layout(); + int64 pdim_size = layout.padded_dimensions_size(); + int64 stride = 1; + DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); + for (auto dim : layout.minor_to_major()) { + if (dim == dimension) { + break; + } + if (pdim_size == 0) { + stride *= shape.dimensions(dim); + } else { + stride *= layout.padded_dimensions(dim); + } + } + return stride; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 2d8753c3fe8..c9838966a5b 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -58,7 +58,16 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, std::vector* indices); + static bool BumpIndices(const Shape& shape, + tensorflow::gtl::MutableArraySlice indices); + + // Calculates the stride size (in number of elements, not byte size) of a + // given logical shape dimension (from 0 to rank-1). If available, padded + // dimensions are used. 
+ // Example: + // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == + // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 + static int64 GetDimensionStride(const Shape& shape, int64 dimension); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 85259b33f0b..7c4efdee484 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -144,14 +143,11 @@ TEST(IndexUtilTest, BumpIndices2x2) { auto shape = ShapeUtil::MakeShape(S32, {2, 2}); std::vector indices = {0, 0}; EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{0, 1})); + EXPECT_THAT(indices, ::testing::ElementsAre(0, 1)); EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{1, 0})); + EXPECT_THAT(indices, ::testing::ElementsAre(1, 0)); EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{1, 1})); + EXPECT_THAT(indices, ::testing::ElementsAre(1, 1)); EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); } diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 531a6e03dad..d3fcccff654 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,11 +14,10 @@ limitations under the License. 
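As a worked version of the `GetDimensionStride` example above: the stride of a dimension is the product of the sizes of every dimension that is more minor in the layout. A self-contained sketch (illustrative names; only the unpadded branch is modeled):

```c++
#include <cstdint>
#include <vector>

// Stride, in elements, of `dimension`, given sizes and minor-to-major order.
int64_t DimensionStrideSketch(const std::vector<int64_t>& dims,
                              const std::vector<int64_t>& minor_to_major,
                              int64_t dimension) {
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    if (dim == dimension) break;  // Stop once the queried dimension is reached.
    stride *= dims[dim];
  }
  return stride;
}

// For F32[5,8,10,4]{3,2,1,0}: DimensionStrideSketch({5,8,10,4}, {3,2,1,0}, 1)
// multiplies the sizes of dimensions 3 and 2, giving 4 * 10 == 40.
```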
==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/shape_util.h" - #include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -114,8 +113,8 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { @@ -133,8 +132,8 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { @@ -145,9 +144,10 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("layout minor_to_major field contains .* " - "elements, but shape is rank")); + EXPECT_THAT( + status.error_message(), + ::testing::ContainsRegex("layout minor_to_major field contains .* " + "elements, but shape is rank")); } TEST_F(LayoutUtilTest, ClearLayoutTuple) { diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 79ff81262e9..a147ce67a28 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -29,6 +29,7 @@ cc_library( cc_test( name = "parse_flags_from_env_test", + size = "small", srcs = ["parse_flags_from_env_test.cc"], deps = [ @@ -65,6 +66,20 @@ cc_library( ], ) +cc_library( + name = "debug_options_flags", + srcs = ["debug_options_flags.cc"], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "cpu_compiler_flags", srcs = ["cpu_compiler_flags.cc"], @@ -160,18 +175,6 @@ cc_library( ], ) -cc_library( - name = "hlo_pass_pipeline_flags", - srcs = ["hlo_pass_pipeline_flags.cc"], - hdrs = ["hlo_pass_pipeline_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "alias_analysis_flags", srcs = ["alias_analysis_flags.cc"], @@ -240,6 +243,18 @@ cc_library( ], ) +cc_library( + name = "user_computation_flags", + srcs = ["user_computation_flags.cc"], + hdrs = ["user_computation_flags.h"], + deps = [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc 
b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc index f8ae25552d4..13d41a8636b 100644 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc @@ -36,23 +36,15 @@ static std::once_flag flags_init; // Allocate *flags. Called via call_once(&flags_init,...). static void AllocateFlags() { flags = new CpuCompilerFlags; - flags->xla_cpu_llvm_opt_level = 2; - flags->xla_cpu_llvm_cl_opts = ""; flags->xla_cpu_embed_ir = false; - flags->xla_cpu_parallel = false; + flags->xla_cpu_dump_debug_json_to = ""; flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_llvm_opt_level", &flags->xla_cpu_llvm_opt_level, - "The LLVM optimization level for the CPU XLA backend. " - "Valid range is from 0 to 3 where 0 means no optimizations."), - tensorflow::Flag( - "xla_cpu_llvm_cl_opts", &flags->xla_cpu_llvm_cl_opts, - "Comma-separated list of command line options to pass to LLVM."), tensorflow::Flag( "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir, "Embed the LLVM IR module string in the resultant CpuExecutable."), - tensorflow::Flag("xla_cpu_parallel", &flags->xla_cpu_parallel, - "Use the multi-threaded CPU backend."), + tensorflow::Flag("xla_cpu_dump_debug_json_to", + &flags->xla_cpu_dump_debug_json_to, + "Dump debug JSON to this directory."), }); ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h index 16a7b687116..bac498e18eb 100644 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h @@ -33,14 +33,9 @@ void AppendCpuCompilerFlags(std::vector* flag_list); // The values of flags associated with XLA's cpu_compiler module. typedef struct { - // The LLVM optimization level for the CPU XLA backend. - // Valid range is from 0 to 3 where 0 means no optimizations. - int32 xla_cpu_llvm_opt_level; - string xla_cpu_llvm_cl_opts; // Comma-separated list of command line options - // to pass to LLVM. bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant // CpuExecutable - bool xla_cpu_parallel; // Use the multi-threaded CPU backend. + string xla_cpu_dump_debug_json_to; // Dump debug JSON to this directory. } CpuCompilerFlags; // Return a pointer to the CpuCompilerFlags struct; diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc new file mode 100644 index 00000000000..5e3c4f912bf --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace legacy_flags {
+
+struct DebugOptionsFlags {
+  string xla_generate_hlo_graph;
+  string xla_disable_hlo_passes;
+  bool xla_enable_fast_math;
+  int32 xla_backend_optimization_level;
+  string xla_backend_extra_options;
+};
+
+namespace {
+
+DebugOptionsFlags* flag_values;
+std::vector<tensorflow::Flag>* flag_objects;
+std::once_flag flags_init;
+
+// Allocates flag_values and flag_objects; this function must not be called
+// more than once - it is invoked exactly once via std::call_once.
+void AllocateFlags() {
+  flag_values = new DebugOptionsFlags;
+  flag_values->xla_generate_hlo_graph = "";
+  flag_values->xla_disable_hlo_passes = "";
+  flag_values->xla_enable_fast_math = true;
+  flag_values->xla_backend_optimization_level = 2;
+  flag_values->xla_backend_extra_options = "";
+
+  flag_objects = new std::vector<tensorflow::Flag>(
+      {tensorflow::Flag(
+           "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
+           "HLO modules matching this regex will be dumped to a .dot file "
+           "throughout various stages in compilation."),
+
+       tensorflow::Flag(
+           "xla_enable_fast_math", &flag_values->xla_enable_fast_math,
+           "Enable unsafe fast-math optimizations in the compiler; "
+           "this may produce faster code at the expense of some accuracy."),
+       tensorflow::Flag(
+           "xla_backend_optimization_level",
+           &flag_values->xla_backend_optimization_level,
+           "Numerical optimization level for the XLA compiler backend."),
+
+       tensorflow::Flag("xla_backend_extra_options",
+                        &flag_values->xla_backend_extra_options,
+                        "Extra options to pass to a backend; "
+                        "comma-separated list of 'key=val' strings (=val "
+                        "may be omitted); no whitespace around commas."),
+
+       tensorflow::Flag(
+           "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
+           "Comma-separated list of HLO passes to be disabled. These names "
+           "must exactly match the passes' names; "
+           "no whitespace around commas.")});
+  ParseFlagsFromEnv(*flag_objects);
+}
+
+}  // namespace
+
+void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
+  std::call_once(flags_init, &AllocateFlags);
+  flag_list->insert(flag_list->end(), flag_objects->begin(),
+                    flag_objects->end());
+}
+
+xla::DebugOptions GetDebugOptionsFromFlags() {
+  std::call_once(flags_init, &AllocateFlags);
+
+  DebugOptions options;
+  options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph);
+
+  std::vector<string> disabled_passes =
+      tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ',');
+  for (const auto& passname : disabled_passes) {
+    options.add_xla_disable_hlo_passes(passname);
+  }
+
+  options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math);
+  options.set_xla_backend_optimization_level(
+      flag_values->xla_backend_optimization_level);
+
+  std::vector<string> extra_options_parts =
+      tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ',');
+  auto* extra_options_map = options.mutable_xla_backend_extra_options();
+
+  // The flag contains a comma-separated list of options; some options have
+  // arguments following "=", some don't.
+ for (const auto& part : extra_options_parts) { + size_t eq_pos = part.find_first_of('='); + if (eq_pos == string::npos) { + (*extra_options_map)[part] = ""; + } else { + string value = ""; + if (eq_pos + 1 < part.size()) { + value = part.substr(eq_pos + 1); + } + (*extra_options_map)[part.substr(0, eq_pos)] = value; + } + } + + return options; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h new file mode 100644 index 00000000000..d0ef8e66ab0 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ + +#include + +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Appends flag definitions for debug options to flag_list. +void AppendDebugOptionsFlags(std::vector* flag_list); + +// Fetches a DebugOptions proto message from flags provided to the program. +// Flags must be registered with the flags parser using AppendDebugOptionsFlags +// first. 
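The loop above turns a string such as `"key=val,flag"` into the extra-options map, with a bare key mapped to an empty value. A standalone STL-only equivalent (hypothetical helper; the real code relies on `tensorflow::str_util::Split`):

```c++
#include <map>
#include <string>

// Splits "a=1,b,c=xyz" into {{"a","1"},{"b",""},{"c","xyz"}}, matching the
// semantics of the xla_backend_extra_options parsing above.
std::map<std::string, std::string> ParseExtraOptions(const std::string& s) {
  std::map<std::string, std::string> result;
  if (s.empty()) return result;
  size_t start = 0;
  while (start <= s.size()) {
    size_t comma = s.find(',', start);
    if (comma == std::string::npos) comma = s.size();
    const std::string part = s.substr(start, comma - start);
    const size_t eq = part.find('=');
    if (eq == std::string::npos) {
      result[part] = "";  // Bare key: empty value, as in the loop above.
    } else {
      result[part.substr(0, eq)] = part.substr(eq + 1);
    }
    start = comma + 1;
  }
  return result;
}
```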
+xla::DebugOptions GetDebugOptionsFromFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc index c355b1ed9b7..f8f6ea26b1d 100644 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc @@ -38,7 +38,6 @@ static void AllocateFlags() { flags->dump_temp_products_to = ""; flags->ftz = false; flags->fma = true; - flags->gpu_architecture = "compute_35"; flags->verbose_ptx_asm = false; flags->kernel = ""; flags->llvm_dump_passes = false; @@ -51,8 +50,6 @@ static void AllocateFlags() { "If empty, no dump is produced"), tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"), tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"), - tensorflow::Flag("gpu_architecture", &flags->gpu_architecture, - "GPU architecture"), tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm, "emit PTX assembly with extra comments"), tensorflow::Flag("kernel", &flags->kernel, diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h index fbb88634545..31cb50e9da9 100644 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h @@ -36,7 +36,6 @@ typedef struct { string dump_temp_products_to; // temporary compilation products dir bool ftz; // flush to zero semantics bool fma; // use FMA synthesis - string gpu_architecture; // GPU architecture bool verbose_ptx_asm; // emit PTX assembly with extra comments string kernel; // only emit the IR and PTX for this kernel bool llvm_dump_passes; // dump the passes LLVM runs to stderr diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc index e79d3635095..131e3ce70ac 100644 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc @@ -38,7 +38,7 @@ static void AllocateFlags() { flags = new GpuCompilerFlags; flags->xla_gpu_embed_ir = false; flags->xla_cuda_data_dir = "./cuda_sdk_lib"; - flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas"; + flags->xla_gpu_dump_debug_json_to = ""; flag_list = new std::vector({ tensorflow::Flag( "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, @@ -50,6 +50,9 @@ static void AllocateFlags() { "runfile directories."), tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path, "The path to ptxas. Required to log stats of the ptx."), + tensorflow::Flag("xla_gpu_dump_debug_json_to", + &flags->xla_gpu_dump_debug_json_to, + "Dump debug JSON to this directory."), }); ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h index 04ddedab732..0cf39e0ab35 100644 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h @@ -41,6 +41,7 @@ typedef struct { // directories. string xla_ptxas_path; // The path to ptxas. Required to log stats of // the ptx. + string xla_gpu_dump_debug_json_to; // Dump debug JSON to this directory. 
} GpuCompilerFlags; // Return a pointer to the GpuCompilerFlags struct; diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc index 8822f6f6107..ba43a591952 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc @@ -36,10 +36,14 @@ static std::once_flag flags_init; static void AllocateFlags() { flags = new HloGraphDumperFlags; flags->xla_hlo_dump_graph_path = "/tmp/"; + flags->xla_hlo_dump_as_graphdef = false; flag_list = new std::vector({ tensorflow::Flag("xla_hlo_dump_graph_path", &flags->xla_hlo_dump_graph_path, "Path to write dumped HLO graphs to"), + tensorflow::Flag("xla_hlo_dump_as_graphdef", + &flags->xla_hlo_dump_as_graphdef, + "Dumps HLO graphs as tensorflow GraphDefs"), }); ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h index b6dfced87ca..d0b4d092ff1 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h @@ -34,6 +34,9 @@ void AppendHloGraphDumperFlags(std::vector* flag_list); // The values of flags associated with XLA's hlo_graph_dumper module. typedef struct { string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to + // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO + // graphs as DOT graph. + bool xla_hlo_dump_as_graphdef; } HloGraphDumperFlags; // Return a pointer to the HloGraphDumperFlags struct; diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc deleted file mode 100644 index edc04d51a70..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_pass_pipeline module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloPassPipelineFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new HloPassPipelineFlags; - flags->xla_disable_hlo_passes = ""; - flag_list = new std::vector({ - tensorflow::Flag("xla_disable_hlo_passes", &flags->xla_disable_hlo_passes, - "Comma-separated list of HLO passes to disable."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendHloPassPipelineFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloPassPipelineFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloPassPipelineFlags* GetHloPassPipelineFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h deleted file mode 100644 index 520759bbf0d..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ - -// Legacy flags for XLA's hlo_pass_pipeline module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendHloPassPipelineFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_pass_pipeline module. -typedef struct { - // Comma-separated list of HLO passes to disable. - string xla_disable_hlo_passes; -} HloPassPipelineFlags; - -// Return a pointer to the HloPassPipelineFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
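Every module in `legacy_flags`, including the deleted one above and its replacements, uses the same lazily-initialized flag-singleton pattern: allocate once under `std::call_once`, then hand out the same pointer forever. A generic sketch of that pattern (names are illustrative, not a real XLA module):

```c++
#include <mutex>  // NOLINT(build/c++11): only std::call_once is used.

struct MyFlags {
  bool some_flag = false;
};

namespace {
MyFlags* flags;             // Intentionally leaked; lives for the process.
std::once_flag flags_init;

void AllocateFlags() {
  flags = new MyFlags;
  // ...register the flags with the command-line/environment parser here...
}
}  // namespace

// Repeated calls return the same pointer; allocation happens exactly once,
// even if multiple threads race to get here first.
MyFlags* GetMyFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}
```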
-HloPassPipelineFlags* GetHloPassPipelineFlags();
-
-}  // namespace legacy_flags
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc
index 4242b501d41..f838861898d 100644
--- a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc
@@ -53,7 +53,7 @@ static void AllocateRawFlag() {
 static bool ParseDefaultLayout(const string& text, DefaultLayout* layout) {
   bool result = true;
   std::vector<string> field = tensorflow::str_util::Split(text, ':');
-  if (field.size() > 0) {
+  if (!field.empty()) {
     if (field[0] == "random") {
       layout->dimension_order = DefaultLayout::DimensionOrder::kRandom;
       if (field.size() > 1) {
diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc
new file mode 100644
index 00000000000..a9597d0cd8f
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static UserComputationFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+  flags = new UserComputationFlags;
+  flags->xla_eliminate_hlo_implicit_broadcast = false;
+  flag_list = new std::vector<tensorflow::Flag>({
+      tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
+                       &flags->xla_eliminate_hlo_implicit_broadcast,
+                       "Eliminate implicit broadcasts when lowering user "
+                       "computations to HLO instructions; use explicit "
+                       "broadcasts instead."),
+  });
+  ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's user_computation
+// module.
+void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) {
+  std::call_once(flags_init, &AllocateFlags);
+  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the UserComputationFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+UserComputationFlags* GetUserComputationFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h new file mode 100644 index 00000000000..f5222c927cb --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ + +// Legacy flags for XLA's user_computation module. + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Append to *flag_list flags definitions associated with XLA's user_computation +// module. +void AppendUserComputationFlags(std::vector* flag_list); + +typedef struct { + // Eliminate implicit broadcast on when lowering user computation to HLO + // instructions, use explicit broadcast instead. + bool xla_eliminate_hlo_implicit_broadcast; +} UserComputationFlags; + +// Return a pointer to the UserComputationFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +UserComputationFlags* GetUserComputationFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index b8bb56a97b2..caef3a3869f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -16,12 +16,15 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include +#include +#include #include #include #include #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -33,31 +36,151 @@ limitations under the License. namespace xla { -/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { +Literal::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the highest minor dimension as the one upon + // where to run the tight stride loop. 
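In the `Literal` changes that follow, `StrideConfig` picks the dimension for the tight copy loop and `CopyRange` then hands each run of elements to `StridedCopy`. A sketch of the strided-copy primitive being relied on (name and signature are modeled on the call sites below, not taken from the XLA helper itself):

```c++
#include <cstdint>
#include <vector>

// Copies `count` elements such that
//   dest[dest_base + i * dest_stride] = src[src_base + i * src_stride].
// With count == 1 and strides of 0 this degenerates to a single-element copy,
// which is how the scalar case below uses it.
template <typename T>
void StridedCopySketch(std::vector<T>* dest, int64_t dest_base,
                       int64_t dest_stride, const std::vector<T>& src,
                       int64_t src_base, int64_t src_stride, int64_t count) {
  for (int64_t i = 0; i < count; ++i) {
    (*dest)[dest_base + i * dest_stride] = src[src_base + i * src_stride];
  }
}
```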
+    if (source_shape.layout().minor_to_major()[0] >=
+        dest_shape.layout().minor_to_major()[0]) {
+      minor_dimension = source_shape.layout().minor_to_major()[0];
+      dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
+    } else {
+      minor_dimension = dest_shape.layout().minor_to_major()[0];
+      source_stride =
+          IndexUtil::GetDimensionStride(source_shape, minor_dimension);
+    }
+    minor_loop_size = dimensions[minor_dimension];
+    step[minor_dimension] = minor_loop_size;
+  }
+}
+
+std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
+  auto literal = MakeUnique<Literal>();
+  *literal->mutable_shape() = shape;
+  literal->Reserve(ShapeUtil::ElementsIn(literal->shape()));
+  return literal;
+}
+
+/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
+    PrimitiveType primitive_type,
+    tensorflow::gtl::ArraySlice<int64> dimensions) {
+  return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
+}
+
+template <typename T>
+Status Literal::CopyRange(const Literal& src_literal,
+                          tensorflow::gtl::ArraySlice<int64> src_base,
+                          tensorflow::gtl::ArraySlice<int64> dest_base,
+                          tensorflow::gtl::ArraySlice<int64> copy_size) {
+  const Shape& src_shape = src_literal.shape();
+  const Shape& dest_shape = shape();
+  tensorflow::gtl::ArraySlice<T> src_data = src_literal.GetArraySlice<T>();
+  tensorflow::gtl::MutableArraySlice<T> dest_data = GetMutableArraySlice<T>();
+
+  TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size());
+  TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size());
+  if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) {
+    // If either of the two shapes is a scalar, we can just call StridedCopy()
+    // directly, and we know we will be copying only one value.
+    TF_RET_CHECK(copy_size.empty());
+    StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data,
+                src_literal.LinearIndex(src_base), 0, 1);
+  } else if (!ShapeUtil::HasZeroElements(dest_shape)) {
+    TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape));
+    TF_RET_CHECK(src_base.size() == dest_base.size());
+    TF_RET_CHECK(src_base.size() == copy_size.size());
+
+    // Scan the source from minor, stepping in copy size blocks, then within
+    // the index enumeration functor, do a strided copy advancing the source
+    // index by one (walking through the minor dimension), and the destination
+    // index by the proper stride size at the matching dimension.
+    DimensionVector src_indexes(src_base.size(), 0);
+    DimensionVector dest_indexes(dest_base.size(), 0);
+    StrideConfig stride_config(src_shape, dest_shape, copy_size);
+
+    auto copy_proc = [&](const std::vector<int64>& indexes) {
+      // Map from multi-dimensional index, to source index.
+      std::transform(indexes.begin(), indexes.end(), src_base.begin(),
+                     src_indexes.begin(), std::plus<int64>());
+      // Map from multi-dimensional index, to destination index.
+ std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = src_literal.LinearIndex(src_indexes); + int64 dest_index = LinearIndex(dest_indexes); + + StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, + src_index, stride_config.source_stride, + stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +Status Literal::Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); + switch (src_literal.shape().element_type()) { + case U32: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case U64: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case S32: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case S64: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case F16: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case F32: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case F64: + return CopyRange(src_literal, src_base, dest_base, copy_size); + case PRED: + return CopyRange(src_literal, src_base, dest_base, copy_size); + default: + break; + } + return Unimplemented("Unhandled primitive type %d", + src_literal.shape().element_type()); +} + +/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case U32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case U64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S8: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case S64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); + case F16: + return *Literal::CreateR0(static_cast(0.0f)); case F32: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case F64: - return *LiteralUtil::CreateR0(0); + return *Literal::CreateR0(0); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - LOG(FATAL) << "f16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -67,31 +190,31 @@ namespace xla { } } -/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { +/* static */ Literal Literal::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case U64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S8: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case S64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F32: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case F64: - return *LiteralUtil::CreateR0(1); + return *Literal::CreateR0(1); case PRED: - return *LiteralUtil::CreateR0(true); + return 
*Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -101,33 +224,33 @@ namespace xla { } } -/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::min()); + return *Literal::CreateR0(std::numeric_limits::min()); case F32: - return *LiteralUtil::CreateR0( - -std::numeric_limits::infinity()); + return *Literal::CreateR0(-std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( -std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(false); + return *Literal::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *Literal::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -137,33 +260,33 @@ namespace xla { } } -/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case U32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case U64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S8: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S32: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case S64: - return *LiteralUtil::CreateR0(std::numeric_limits::max()); + return *Literal::CreateR0(std::numeric_limits::max()); case F32: - return *LiteralUtil::CreateR0( - std::numeric_limits::infinity()); + return *Literal::CreateR0(std::numeric_limits::infinity()); case F64: - return *LiteralUtil::CreateR0( + return *Literal::CreateR0( std::numeric_limits::infinity()); case PRED: - return *LiteralUtil::CreateR0(true); + return *Literal::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *Literal::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -173,191 +296,161 @@ 
namespace xla { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( +/* static */ std::unique_ptr Literal::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique(); *literal->mutable_shape() = ShapeUtil::MakeShape(U8, {static_cast(value.size())}); - literal->set_u8s(value.ToString()); + literal->set_u8s(tensorflow::StringPiece(value.ToString())); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, + float to, + int64 rows, + int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::Relayout( - const Literal& original, const Layout& layout) { - // Note: if this were a performance bottleneck, we avoid cloning and just make - // an uninitialized array instead, since all values are clobbered below. - std::unique_ptr result = CloneToUnique(original); +std::unique_ptr Literal::Relayout(const Layout& layout) const { + std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; - const PrimitiveType primitive_type = original.shape().element_type(); - switch (primitive_type) { - case F32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, float value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case S32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, int32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case U32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, uint32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); - } + + DimensionVector base(ShapeUtil::Rank(shape()), 0); + DimensionVector copy_size(shape().dimensions().begin(), + shape().dimensions().end()); + + TF_CHECK_OK(result->Copy(*this, base, base, copy_size)); + return result; } -/* static */ StatusOr> LiteralUtil::Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice dimensions) { - if (ShapeUtil::IsTuple(input.shape())) { +StatusOr> Literal::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (ShapeUtil::IsTuple(shape())) { return InvalidArgument("Reshape does not support tuples."); } - - if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { - return Unimplemented( - "Input shape must have a monotonic layout where dimension 0 is major, " - "was: %s", - LayoutUtil::HumanString(input.shape().layout()).c_str()); + std::unique_ptr output; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + std::vector minor_to_major(ShapeUtil::Rank(shape())); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), + static_cast(0)); + output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); + } else { + output = CloneToUnique(); } - std::vector layout(dimensions.size()); - std::iota(layout.rbegin(), layout.rend(), 0); - // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
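The comment above is the crux of the new in-place `Reshape`: with a monotonic dim0-major (row-major) layout, reshaping is purely a reinterpretation of the same flat value sequence. A tiny illustration in plain C++ (hypothetical, no XLA types):

```c++
#include <cassert>
#include <vector>

int main() {
  // Flat row-major data for shape [2,3]: {a00, a01, a02, a10, a11, a12}.
  std::vector<int> flat = {0, 1, 2, 3, 4, 5};
  // Viewed as shape [3,2], still row-major, element (i,j) is flat[i*2 + j];
  // the rows become {0,1}, {2,3}, {4,5} with no data movement at all.
  assert(flat[1 * 2 + 0] == 2);  // Element (1,0) of the [3,2] view.
  return 0;
}
```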
- std::unique_ptr output = CloneToUnique(input); - output->clear_shape(); - output->mutable_shape()->set_element_type(input.shape().element_type()); - for (int64 dimension : dimensions) { - output->mutable_shape()->add_dimensions(dimension); - } - *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout); + *output->mutable_shape() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); - int64 elements_before = ShapeUtil::ElementsIn(input.shape()); + int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); if (elements_before != elements_after) { return InvalidArgument( - "Shapes before and after LiteralUtil::Reshape have different numbers " + "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(input.shape()).c_str(), + ShapeUtil::HumanString(shape()).c_str(), ShapeUtil::HumanString(output->shape()).c_str()); } return std::move(output); } -/* static */ std::unique_ptr LiteralUtil::Transpose( - const Literal& original, tensorflow::gtl::ArraySlice permutation) { - CHECK(!ShapeUtil::IsTuple(original.shape())) - << "tuple is not supported for transpose"; - std::vector dimension_numbers(ShapeUtil::Rank(original.shape())); - std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0); - CHECK(std::is_permutation(permutation.begin(), permutation.end(), - dimension_numbers.begin())) - << "given permutation is not a permutation of dimension numbers"; - std::vector new_dimension_sizes; - for (const int64 dim : permutation) { - new_dimension_sizes.push_back(original.shape().dimensions(dim)); - } - const auto result_shape = ShapeUtil::MakeShape( - original.shape().element_type(), new_dimension_sizes); - std::unique_ptr result = CloneToUnique(original); - *result->mutable_shape() = result_shape; - const PrimitiveType primitive_type = original.shape().element_type(); - std::vector new_indices(ShapeUtil::Rank(original.shape())); - switch (primitive_type) { - case F32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, float value) { - for (int64 i = 0; i < permutation.size(); ++i) { - new_indices[i] = indices[permutation[i]]; - } - LiteralUtil::Set(result.get(), new_indices, value); - }); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); +std::unique_ptr Literal::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the + // most minor. 
+ // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + Layout* layout = permuted_shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : shape().layout().minor_to_major()) { + layout->add_minor_to_major(inverse_permutation[index]); } + std::unique_ptr new_literal = CreateFromShape(permuted_shape); + DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->MutableInternalData(), InternalData(), + ShapeUtil::ByteSizeOf(shape())); + return new_literal; } -/* static */ std::unique_ptr LiteralUtil::Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - CHECK(!ShapeUtil::IsTuple(literal.shape())) - << "tuple is not supported for reshape"; +std::unique_ptr Literal::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape"; - std::vector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { + DimensionVector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); int64 dimension = limit_indices[dnum] - start_indices[dnum]; CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } const auto result_shape = ShapeUtil::MakeShapeWithLayout( - literal.shape().element_type(), result_dimensions, - AsInt64Slice(literal.shape().layout().minor_to_major())); + shape().element_type(), result_dimensions, + AsInt64Slice(shape().layout().minor_to_major())); auto result_literal = MakeUnique(); *result_literal->mutable_shape() = result_shape; - Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); + result_literal->Reserve(ShapeUtil::ElementsIn(result_shape)); - std::vector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - float value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + float value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case S32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } - int32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + int32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; case U32: - LiteralUtil::EachCell( - *result_literal, + result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); 
++i) { new_indices[i] = indices[i] + start_indices[i]; } - uint32 value = LiteralUtil::Get(literal, new_indices); - LiteralUtil::Set(result_literal.get(), indices, value); + uint32 value = Get(new_indices); + result_literal->Set(indices, value); }); return result_literal; default: @@ -366,96 +459,95 @@ namespace xla { } } -/* static */ std::unique_ptr LiteralUtil::CloneToUnique( - const Literal& literal) { +std::unique_ptr Literal::CloneToUnique() const { auto unique = MakeUnique(); - *unique = literal; + *unique = *this; return unique; } -/* static */ string LiteralUtil::GetAsString( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - switch (literal.shape().element_type()) { +string Literal::GetAsString( + tensorflow::gtl::ArraySlice multi_index) const { + switch (shape().element_type()) { case PRED: - return Get(literal, multi_index) ? "true" : "false"; + return Get(multi_index) ? "true" : "false"; case U8: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case S64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case U64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F32: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); case F64: - return tensorflow::strings::StrCat(Get(literal, multi_index)); + return tensorflow::strings::StrCat(Get(multi_index)); + case F16: + return tensorflow::strings::StrCat(Get(multi_index)); default: return tensorflow::strings::StrCat( - "[", PrimitiveType_Name(literal.shape().element_type()), "]"); + "[", PrimitiveType_Name(shape().element_type()), "]"); } } -/* static */ int64 LiteralUtil::LinearIndex( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), - multi_index); +int64 Literal::LinearIndex( + tensorflow::gtl::ArraySlice multi_index) const { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -/* static */ string LiteralUtil::ToString(const Literal& literal) { - const Shape& shape = literal.shape(); +string Literal::ToString() const { std::vector pieces; auto element_to_string = - [&literal](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = literal.shape().element_type(); + [this](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = shape().element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. - return Get(literal, indices) ? "1" : "0"; + return Get(indices) ? "1" : "0"; } return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - GetAsString(literal, indices); + GetAsString(indices); }; // TODO(b/32894291): refactor this code to reduce code duplication. 
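`Slice` above copies the half-open window `[start_indices, limit_indices)` of each dimension into a fresh literal. The index arithmetic, reduced to a plain row-major 2-D array (illustrative helper, no XLA types):

```c++
#include <cstdint>
#include <vector>

// Copies rows [r0, r1) and columns [c0, c1) of a row-major `src` that has
// `cols` columns; destination element (r - r0, c - c0) comes from (r, c),
// following the same half-open convention Literal::Slice uses.
std::vector<int> Slice2D(const std::vector<int>& src, int64_t cols, int64_t r0,
                         int64_t r1, int64_t c0, int64_t c1) {
  std::vector<int> out;
  out.reserve((r1 - r0) * (c1 - c0));
  for (int64_t r = r0; r < r1; ++r) {
    for (int64_t c = c0; c < c1; ++c) {
      out.push_back(src[r * cols + c]);
    }
  }
  return out;
}
```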
- if (ShapeUtil::IsTuple(shape)) { - pieces.push_back(ShapeUtil::HumanString(shape)); + if (ShapeUtil::IsTuple(shape())) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" (\n"); - for (const auto& element_literal : literal.tuple_literals()) { - pieces.push_back(ToString(element_literal)); + for (const auto& element_literal : tuple_literals()) { + pieces.push_back(element_literal.ToString()); pieces.push_back(",\n"); } pieces.push_back(")"); - } else if (ShapeUtil::Rank(shape) == 0) { - pieces.push_back(GetAsString(literal, {})); - } else if (ShapeUtil::Rank(shape) == 1) { + } else if (ShapeUtil::Rank(shape()) == 0) { + pieces.push_back(GetAsString({})); + } else if (ShapeUtil::Rank(shape()) == 1) { pieces.push_back("{"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(element_to_string({i0})); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 2) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(element_to_string({i0, i1})); } pieces.push_back(" "); pieces.push_back("},\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 3) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back(i1 > 0 ? 
",\n { " : " { "); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(element_to_string({i0, i1, i2})); } pieces.push_back(" }"); @@ -463,17 +555,17 @@ namespace xla { pieces.push_back(" }"); } pieces.push_back("\n}"); - } else if (ShapeUtil::Rank(shape) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 4) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(" {"); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(element_to_string({i0, i1, i2, i3})); } pieces.push_back("},\n"); @@ -483,20 +575,20 @@ namespace xla { pieces.push_back(" },\n"); } pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape)); + } else if (ShapeUtil::Rank(shape()) == 5) { + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) { + for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); - for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) { + for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( tensorflow::strings::Printf(" { // i1=%lld\n", i1)); - for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) { + for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back( tensorflow::strings::Printf(" { // i2=%lld\n", i2)); - for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) { + for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(" {"); - for (int64 i4 = 0; i4 < shape.dimensions(4); ++i4) { + for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); } pieces.push_back("},\n"); @@ -509,14 +601,14 @@ namespace xla { } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape)); + pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {...}"); } return tensorflow::str_util::Join(pieces, ""); } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { auto literal = MakeUnique(); std::vector shape; @@ -528,169 +620,197 @@ namespace xla { return literal; } -/* static */ const void* LiteralUtil::InternalData(const Literal& literal) { - switch (literal.shape().element_type()) { +const void* Literal::InternalData() const { + return const_cast( + const_cast(this)->MutableInternalData()); +} + +void* Literal::MutableInternalData() { + // NOTE: We access the vectors directly to avoid the const reference + // created by the accessor functions. 
+ switch (shape().element_type()) { case PRED: - return reinterpret_cast(literal.preds().data()); + return reinterpret_cast(preds_.data()); case U8: - return reinterpret_cast(literal.u8s().data()); + return reinterpret_cast(u8s_.data()); case S32: - return reinterpret_cast(literal.s32s().data()); + return reinterpret_cast(s32s_.data()); case S64: - return reinterpret_cast(literal.s64s().data()); + return reinterpret_cast(s64s_.data()); case U32: - return reinterpret_cast(literal.u32s().data()); + return reinterpret_cast(u32s_.data()); case U64: - return reinterpret_cast(literal.u64s().data()); + return reinterpret_cast(u64s_.data()); case F32: - return reinterpret_cast(literal.f32s().data()); + return reinterpret_cast(f32s_.data()); case F64: - return reinterpret_cast(literal.f64s().data()); + return reinterpret_cast(f64s_.data()); + case F16: + return reinterpret_cast(f16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal.shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ void* LiteralUtil::MutableInternalData(Literal* literal) { - return const_cast(LiteralUtil::InternalData(*literal)); -} - -/* static */ void LiteralUtil::Reserve(int64 num_elements, Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - switch (literal->shape().element_type()) { +void Literal::Reserve(int64 num_elements) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + switch (shape().element_type()) { case PRED: - GetMutableRepeatedField(literal)->Resize(num_elements, false); + Resize(num_elements, false); + break; + case S8: + Resize(num_elements, 0); break; case U8: - // u8s is an optional "bytes", rather than a repeated field. Therefore its - // access methods are somewhat different from the others. 
- literal->mutable_u8s()->resize(num_elements, 0); + Resize(num_elements, 0); break; case S32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0); break; case S64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0); break; case U32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0); break; case U64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0); break; case F32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0f); + Resize(num_elements, 0); break; case F64: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0); + Resize(num_elements, 0); + break; + case F16: + Resize(num_elements, static_cast(0.0f)); break; default: LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(literal->shape().element_type()); + << PrimitiveType_Name(shape().element_type()); } } -/* static */ tensorflow::Status LiteralUtil::ValidateLiteral( - const Literal& literal) { - TF_CHECK_OK(ShapeUtil::ValidateShape(literal.shape())); - int64 expected = ShapeUtil::ElementsIn(literal.shape()); +tensorflow::Status Literal::ValidateLiteral() const { + TF_CHECK_OK(ShapeUtil::ValidateShape(shape())); + int64 expected = ShapeUtil::ElementsIn(shape()); int64 actual = -1; - switch (literal.shape().element_type()) { + switch (shape().element_type()) { case PRED: - actual = literal.preds().size(); + actual = preds_size(); break; case U8: - actual = literal.u8s().size(); + actual = u8s_size(); break; case S32: - actual = literal.s32s_size(); + actual = s32s_size(); break; case U32: - actual = literal.u32s_size(); + actual = u32s_size(); break; case S64: - actual = literal.s64s_size(); + actual = s64s_size(); break; case U64: - actual = literal.u64s_size(); + actual = u64s_size(); break; case F32: - actual = literal.f32s_size(); + actual = f32s_size(); break; case F64: - actual = literal.f64s_size(); + actual = f64s_size(); + break; + case F16: + actual = f16s().size() / sizeof(half); break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + - PrimitiveType_Name(literal.shape().element_type())); + PrimitiveType_Name(shape().element_type())); } if (expected != actual) { return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( "literal has bad number of elements for its shape %s: want %lld " "got %lld", - ShapeUtil::HumanString(literal.shape()).c_str(), expected, actual)); + ShapeUtil::HumanString(shape()).c_str(), expected, actual)); } return tensorflow::Status::OK(); } -/* static */ void LiteralUtil::EachCellAsString( - const Literal& literal, - std::function indices, - const string& value)> - per_cell) { - if (ShapeUtil::Rank(literal.shape()) == 1) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - per_cell({i0}, GetAsString(literal, {i0})); - } +void Literal::EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} - if (ShapeUtil::Rank(literal.shape()) == 2) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); 
++i1) { - per_cell({i0, i1}, GetAsString(literal, {i0, i1})); - } - } - return; +namespace { +template +std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); +} + +template +StatusOr> ConvertIfDestTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + return tensorflow::errors::InvalidArgument( + "Unimplemented: ConvertIfDestTypeMatches for type " + + PrimitiveType_Name(src_literal.shape().element_type())); } +} +} - if (ShapeUtil::Rank(literal.shape()) == 3) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2})); - } - } - } - return; +StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (src_literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. 
+ default: + return tensorflow::errors::InvalidArgument( + "Unimplemented: ConvertIfSrcTypeMatches for type " + + PrimitiveType_Name(src_literal.shape().element_type())); } - - if (ShapeUtil::Rank(literal.shape()) == 4) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) { - per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3})); - } - } - } - } - return; - } - - LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape()); } namespace { @@ -704,8 +824,8 @@ template bool EqualElements(const Literal& literal1, const Literal& literal2, int dimension, std::vector* multi_index) { if (dimension == ShapeUtil::Rank(literal1.shape())) { - return (LiteralUtil::Get(literal1, *multi_index) == - LiteralUtil::Get(literal2, *multi_index)); + return (literal1.Get(*multi_index) == + literal2.Get(*multi_index)); } for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { (*multi_index)[dimension] = i; @@ -719,145 +839,197 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } // namespace -/* static */ bool LiteralUtil::Equal(const Literal& literal1, - const Literal& literal2) { - if (!ShapeUtil::Compatible(literal1.shape(), literal2.shape())) { +bool Literal::Equal(const Literal& literal2) const { + if (!ShapeUtil::Compatible(shape(), literal2.shape())) { return false; } - if (ShapeUtil::IsTuple(literal1.shape())) { + if (ShapeUtil::IsTuple(shape())) { // Because the shapes are compatible, they must have the same number of // tuple elements. - CHECK_EQ(literal1.tuple_literals_size(), literal2.tuple_literals_size()); - for (int i = 0; i < literal1.tuple_literals_size(); ++i) { - if (!Equal(literal1.tuple_literals(i), literal2.tuple_literals(i))) { + CHECK_EQ(tuple_literals_size(), literal2.tuple_literals_size()); + for (int i = 0; i < tuple_literals_size(); ++i) { + if (!tuple_literals(i).Equal(literal2.tuple_literals(i))) { return false; } } return true; } else { - std::vector multi_index(ShapeUtil::Rank(literal1.shape()), 0); - switch (literal1.shape().element_type()) { + std::vector multi_index(ShapeUtil::Rank(shape()), 0); + switch (shape().element_type()) { case PRED: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U8: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case S64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case U64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F32: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); case F64: - return EqualElements(literal1, literal2, 0, &multi_index); + return EqualElements(*this, literal2, 0, &multi_index); + case F16: + return EqualElements(*this, literal2, 0, &multi_index); default: - LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " - << 
PrimitiveType_Name(literal1.shape().element_type()); + LOG(FATAL) << "Unimplemented: Literal::Equal for type " + << PrimitiveType_Name(shape().element_type()); } } } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK(literal.shape().element_type() == PRED); - return literal.preds(); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_preds(); + return tensorflow::gtl::MutableArraySlice(values->data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == PRED); - return literal->mutable_preds(); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + auto values = mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U32); - return literal.u32s(); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + auto values = mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == U32); - return literal->mutable_u32s(); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_s32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U64); - return AsUInt64Slice(literal.u64s()); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_u32s(); + return tensorflow::gtl::MutableArraySlice(values->data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == U64); - return literal->mutable_u64s(); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && + alignof(int64) == alignof(tensorflow::protobuf_int64), + "The int64 and tensorflow::protobuf_int64 types are not " + "compatible"); + auto values = mutable_s64s(); + // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t + // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is + // necessary from the raw data pointer returned by the mutable_data() API. 
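To illustrate the typed mutable views being added here, a sketch under the assumption that the literal was allocated with a matching element type:

```c++
// Sketch only: in-place mutation through GetMutableArraySlice<int64>().
auto lit = Literal::CreateFromDimensions(S64, {4});
tensorflow::gtl::MutableArraySlice<int64> vals =
    lit->GetMutableArraySlice<int64>();
for (int64 i = 0; i < ShapeUtil::ElementsIn(lit->shape()); ++i) {
  vals[i] = i * i;  // hypothetical payload
}
```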
+ return tensorflow::gtl::MutableArraySlice<int64>( + reinterpret_cast<int64*>(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice<int32> -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S32); - return literal.s32s(); +tensorflow::gtl::MutableArraySlice<uint64> Literal::GetMutableArraySlice() { + static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && + alignof(uint64) == alignof(tensorflow::protobuf_uint64), + "The uint64 and tensorflow::protobuf_uint64 types are not " + "compatible"); + auto values = mutable_u64s(); + // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t + // while tensorflow::uint64 is defined as unsigned long long, a + // reinterpret_cast<> is necessary from the raw data pointer returned by the + // mutable_data() API. + return tensorflow::gtl::MutableArraySlice<uint64>( + reinterpret_cast<uint64*>(values->data()), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice<int64> -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S64); - return AsInt64Slice(literal.s64s()); +tensorflow::gtl::MutableArraySlice<float> Literal::GetMutableArraySlice() { + auto values = mutable_f32s(); + return tensorflow::gtl::MutableArraySlice<float>(values->data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == S64); - return literal->mutable_s64s(); +tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice() { + auto values = mutable_f64s(); + return tensorflow::gtl::MutableArraySlice<double>(values->data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice<float> -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == F32); - return literal.f32s(); +tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice() { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + // TODO - there is an endianness problem here. 
fix it, or wait for uint16 + // support in protobuf + auto values = mutable_f16s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == F32); - return literal.f32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), PRED); + return tensorflow::gtl::ArraySlice(preds().data(), preds().size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F32); - return literal->mutable_f32s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == F64); - return literal.f64s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(u8s().data()), u8s().size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F64); - return literal->mutable_f64s(); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U32); + return u32s(); +} + +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), U64); + return u64s(); +} + +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S32); + return s32s(); +} + +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), S64); + return s64s(); +} + +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F64); + return f64s(); +} + +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), F16); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(f16s().data()), + f16s().size() / sizeof(half)); } template @@ -865,46 +1037,48 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { auto multi_index = IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - if (LiteralUtil::Get(literal, multi_index) != value) { + if (literal.Get(multi_index) != value) { return false; } } return true; } -/* static */ bool LiteralUtil::IsAll(const Literal& literal, int8 value) { - switch (literal.shape().element_type()) { +bool Literal::IsAll(int8 value) const { + switch (shape().element_type()) { case U8: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U32: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case U64: if (value >= 0) { - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); } return false; case S8: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S32: - return 
AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case S64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); + case F16: + return AllElementsEqualValue(*this, static_cast(value)); case PRED: if (value == 0) { - return AllElementsEqualValue(literal, false); + return AllElementsEqualValue(*this, false); } if (value == 1) { - return AllElementsEqualValue(literal, true); + return AllElementsEqualValue(*this, true); } return false; default: @@ -912,89 +1086,219 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { } } -/* static */ bool LiteralUtil::IsAllFloat(const Literal& literal, float value) { - switch (literal.shape().element_type()) { +bool Literal::IsAllFloat(float value) const { + switch (shape().element_type()) { case F32: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); case F64: - return AllElementsEqualValue(literal, value); + return AllElementsEqualValue(*this, value); + case F16: + return AllElementsEqualValue(*this, static_cast(value)); default: return false; } } -/* static */ bool LiteralUtil::IsZero( - const Literal& literal, tensorflow::gtl::ArraySlice indices) { - switch (literal.shape().element_type()) { +bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { + switch (shape().element_type()) { case U8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case U64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S8: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S32: - return Get(literal, indices) == 0; + return Get(indices) == 0; case S64: - return Get(literal, indices) == 0; + return Get(indices) == 0; case F32: - return Get(literal, indices) == 0.0f; + return Get(indices) == 0.0f; case F64: - return Get(literal, indices) == 0.0; + return Get(indices) == 0.0; + case F16: + return Get(indices) == static_cast(0.0f); case PRED: - return Get(literal, indices) == false; + return Get(indices) == false; default: LOG(FATAL) << "Input literal must be an array."; } } template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); +/* static */ void Literal::Resize(int64 num_elements, bool value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_preds()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, int8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, uint8 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u8s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, int32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + 
mutable_s32s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, uint32 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u32s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, int64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_s64s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, uint64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_u64s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, float value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f32s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, double value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f64s()->resize(num_elements, value); +} + +template <> +void Literal::Resize(int64 num_elements, half value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_f16s()->resize(num_elements, value); +} + +template +static void CopyToRepeatedField(RepeatedFieldT* dest, + const std::vector& src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +template +static void CopyToRepeatedBoolField(RepeatedFieldT* dest, + const BoolVector& src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +LiteralProto Literal::ToProto() const { + LiteralProto proto; + proto.Clear(); + *proto.mutable_shape() = shape(); + switch (shape().element_type()) { + case PRED: + if (preds().begin()) { + CopyToRepeatedBoolField(proto.mutable_preds(), preds()); + } + break; + case U8: + *proto.mutable_u8s() = u8s_string(); + break; + case S32: + CopyToRepeatedField(proto.mutable_s32s(), s32s()); + break; + case S64: + CopyToRepeatedField(proto.mutable_s64s(), s64s()); + break; + case U32: + CopyToRepeatedField(proto.mutable_u32s(), u32s()); + break; + case U64: + CopyToRepeatedField(proto.mutable_u64s(), u64s()); + break; + case F16: + *proto.mutable_f16s() = + string(reinterpret_cast(f16s_.data()), + f16s_.size() * sizeof(half)); + break; + case F32: + CopyToRepeatedField(proto.mutable_f32s(), f32s()); + break; + case F64: + CopyToRepeatedField(proto.mutable_f64s(), f64s()); + break; + case TUPLE: + for (const auto& tuple : tuple_literals()) { + *proto.add_tuple_literals() = tuple.ToProto(); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } + + return proto; } -template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); +template +static void CopyFromRepeatedField(std::vector* dest, + const RepeatedFieldT& src) { + *dest = std::vector(src.begin(), src.end()); +} + +void Literal::CopyFromProto(const LiteralProto& literal_proto) { + if (!literal_proto.has_shape()) { + return; } -} -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - 
repeated_field->Resize(num_elements, value); -} - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); + *mutable_shape() = literal_proto.shape(); + switch (shape().element_type()) { + case PRED: + *mutable_preds() = BoolVector(literal_proto.preds().begin(), + literal_proto.preds().end()); + break; + case U8: + set_u8s(literal_proto.u8s()); + break; + case S32: + CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + break; + case S64: + CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + break; + case U32: + CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + break; + case U64: + CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + break; + case F16: { + const string& s(literal_proto.f16s()); + CHECK_EQ(0, s.size() % sizeof(half)); + f16s_ = std::vector(s.size() / sizeof(half)); + memcpy(f16s_.data(), s.data(), s.size()); + break; + } + case F32: + CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + break; + case F64: + CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); + break; + case TUPLE: + for (const auto& proto : literal_proto.tuple_literals()) { + mutable_tuple_literals()->push_back(Literal(proto)); + } + break; + default: + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + } } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index db467a59113..42c8b61acec 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -47,15 +49,210 @@ limitations under the License. namespace xla { +// This class is a simple vector of boolean values. It's used to workaround some +// implementations of std::vector that use a bitset which does not have +// the semantics expected by Literal::preds(). 
+class BoolVector { + public: + typedef bool* iterator; + typedef const bool* const_iterator; + + BoolVector() : bits_(nullptr), size_(0), capacity_(0) {} + + BoolVector(const_iterator other_begin, const_iterator other_end) + : bits_(nullptr), size_(0), capacity_(0) { + if (other_begin && other_end) { + resize(other_end - other_begin); + memcpy(begin(), other_begin, size()); + } + } + + BoolVector(const BoolVector& other) { CopyFrom(other); } + + BoolVector& operator=(const BoolVector& other) { + CopyFrom(other); + return *this; + } + + void push_back(const bool& value) { + resize(size_ + 1); + bits_[size_ - 1] = value; + } + + bool* data() const { return bits_.get(); } + + size_t size() const { return size_; } + + size_t capacity() const { return capacity_; } + + void resize(size_t new_size, bool val = false) { + if (new_size == 0) { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } else { + size_t old_size = size(); + if (new_size > old_size) { + grow(new_size); + } + if (old_size < new_size) { + memset(&bits_[old_size], val, new_size - old_size); + } + size_ = new_size; + } + } + + void clear() { + bits_.reset(nullptr); + size_ = 0; + capacity_ = 0; + } + + iterator begin() { return &bits_[0]; } + iterator end() { return &bits_[size()]; } + const_iterator begin() const { return &bits_[0]; } + const_iterator end() const { return &bits_[size()]; } + + private: + void grow(size_t n) { + if (capacity_ < n) { + capacity_ = 2 * n; + bool* new_bits = new bool[capacity_](); + if (size_ > 0) { + memcpy(new_bits, bits_.get(), size_); + } + bits_.reset(new_bits); + } + } + + void CopyFrom(const BoolVector& other) { + bits_ = MakeUnique(other.capacity()); + memcpy(begin(), other.begin(), other.size()); + size_ = other.size(); + capacity_ = other.capacity(); + } + + std::unique_ptr bits_; + size_t size_; + size_t capacity_; +}; + // Utility class for dealing with XLA literal values. Most methods are // templated by native (host) type which corresponds to a unique XLA // PrimitiveType. See ComputationBuilder for details. Not all primitive types // defined in xla_data.proto have a corresponding native type or even have a // storage location in the Literal proto yet (for example, primitive type F16). -class LiteralUtil { +class Literal { public: - // Create new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the + Literal() {} + + Literal(const Literal& other) = default; + + explicit Literal(const LiteralProto& other) { CopyFromProto(other); } + + Literal& operator=(const Literal& other) = default; + + LiteralProto ToProto() const; + + bool has_shape() const { + return shape_.element_type() != PRIMITIVE_TYPE_INVALID; + } + + // Basic accessor functions. Names mirror the original protobuf + // functions for convenience. 
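The BoolVector above exists because std::vector<bool> is commonly specialized as a packed bitset with no usable bool* data(); a quick sketch of the byte-per-element behavior it guarantees:

```c++
// Sketch only: BoolVector keeps one byte per element, so data() is a real
// bool* into contiguous storage.
BoolVector preds;
preds.push_back(true);
preds.push_back(false);
preds.resize(4, true);  // now {true, false, true, true}
bool* raw = preds.data();
CHECK(raw[0] && !raw[1] && raw[2] && raw[3]);
```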
+ string DebugString() const { return ToProto().DebugString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + + void Clear() { + shape_.Clear(); + preds_.clear(); + u8s_.clear(); + s32s_.clear(); + s64s_.clear(); + u32s_.clear(); + u64s_.clear(); + f16s_.clear(); + f32s_.clear(); + f64s_.clear(); + tuple_literals_.clear(); + } + + int preds_size() const { return preds().size(); } + const BoolVector& preds() const { return preds_; } + BoolVector* mutable_preds() { return &preds_; } + + int s32s_size() const { return s32s().size(); } + int32 s32s(int i) const { return s32s_[i]; } + const std::vector<int32>& s32s() const { return s32s_; } + std::vector<int32>* mutable_s32s() { return &s32s_; } + + int s64s_size() const { return s64s().size(); } + void add_s64s(int64 value) { s64s_.push_back(value); } + const std::vector<int64>& s64s() const { return s64s_; } + std::vector<int64>* mutable_s64s() { return &s64s_; } + + int u32s_size() const { return u32s().size(); } + uint32 u32s(int i) const { return u32s_[i]; } + const std::vector<uint32>& u32s() const { return u32s_; } + std::vector<uint32>* mutable_u32s() { return &u32s_; } + + int u64s_size() const { return u64s().size(); } + const std::vector<uint64>& u64s() const { return u64s_; } + std::vector<uint64>* mutable_u64s() { return &u64s_; } + + int f16s_size() const { return f16s().size(); } + half f16s(int i) const { return f16s_[i]; } + const std::vector<half>& f16s() const { return f16s_; } + std::vector<half>* mutable_f16s() { return &f16s_; } + + int f32s_size() const { return f32s().size(); } + float f32s(int i) const { return f32s_[i]; } + void add_f32s(float value) { f32s_.push_back(value); } + const std::vector<float>& f32s() const { return f32s_; } + std::vector<float>& f32s() { return f32s_; } + std::vector<float>* mutable_f32s() { return &f32s_; } + + int f64s_size() const { return f64s().size(); } + const std::vector<double>& f64s() const { return f64s_; } + std::vector<double>* mutable_f64s() { return &f64s_; } + + int tuple_literals_size() const { return tuple_literals().size(); } + const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } + Literal* add_tuple_literals() { + tuple_literals_.push_back(Literal()); + return &tuple_literals_.back(); + } + std::vector<Literal>* mutable_tuple_literals() { return &tuple_literals_; } + const std::vector<Literal>& tuple_literals() const { return tuple_literals_; } + + int u8s_size() const { return u8s().size(); } + const std::vector<uint8>& u8s() const { return u8s_; } + void set_u8s(const std::vector<uint8>& value) { u8s_ = value; } + void set_u8s(tensorflow::StringPiece value) { + u8s_ = std::vector<uint8>(value.size()); + u8s_.clear(); + append_u8s(value); + } + + void append_u8s(tensorflow::StringPiece value) { + u8s_.insert(u8s_.end(), value.begin(), value.end()); + } + + string u8s_string() const { return string(u8s().begin(), u8s().end()); } + + std::vector<uint8>* mutable_u8s() { return &u8s_; } + + const Shape& shape() const { return shape_; } + Shape* mutable_shape() { return &shape_; } + + void Swap(Literal* other) { + Literal temp = *this; + *this = *other; + *other = temp; + } + + // Creates a new literal of a given rank. To minimize ambiguity (for users + // and the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: // // CreateR1<float>({1.0, 42.0}); @@ -100,75 +297,98 @@ class LiteralUtil { values, const Layout& layout); - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. 
a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); + + // Creates a new Literal object with its values having the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr<Literal> CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice<int64> dimensions); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, and + // dest_base+copy_size must fit the destination literal dimensions. + Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size); + + // Creates a new value that has the equivalent value as this literal, but + // conforms to new_layout; e.g. a literal matrix that was in {0, 1} + // minor-to-major dimension layout can be re-laid-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. // // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - static std::unique_ptr<Literal> Relayout(const Literal& literal, - const Layout& new_layout); + std::unique_ptr<Literal> Relayout(const Layout& new_layout) const; - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. - static StatusOr<std::unique_ptr<Literal>> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice<int64> shape); + // Creates a new literal by reshaping this literal to have 'shape'. Both the + // original shape and 'shape' must contain the same number of elements. The + // implementation currently only supports monotonic dim0-major layouts. + StatusOr<std::unique_ptr<Literal>> Reshape( + tensorflow::gtl::ArraySlice<int64> shape) const; - // Creates a new literal by reordering the dimensions of the original literal. + // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers // in the original literal, and it specifies the order of the new dimensions // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. 
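A sketch of the permutation semantics documented above, using hypothetical shape values:

```c++
// Sketch only: Transpose() with permutation {2, 0, 1}.
auto lit = Literal::CreateFromDimensions(F32, {3, 8, 4});
std::unique_ptr<Literal> transposed = lit->Transpose({2, 0, 1});
// transposed->shape() is f32[4x3x8]; element (i0, i1, i2) of the result maps
// to element (i1, i2, i0) of the original.
```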
- static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation); + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; - // Creates a sub-array from the the given literal by extracting the indices + // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the // same rank and layout as for the given literal. The number of indices in // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices); + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. template - static std::unique_ptr Replicate(const Literal& input, int64 times); + std::unique_ptr Replicate(int64 times) const; - // Create a literal by converting each element in an original literal to a new + // Creates a literal by converting each element in this literal to a new // type. template - static std::unique_ptr Convert(const Literal& literal); + std::unique_ptr Convert() const; - // Create a literal value zero of the given primitive type. + // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Create a literal value one of the given primitive type. + // Creates a literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); // Creates a literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Create a literal value containing the maximum value of the given + // Creates a literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Create a literal of the given shape where each element is `value`. + // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); - // Create a new literal from an array. The variants not ending with WithLayout - // use the default XLA layout for the literal's linear representation in - // memory. + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); @@ -210,28 +430,33 @@ class LiteralUtil { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal); + // Clones this literal into an owned unique_ptr version. + std::unique_ptr CloneToUnique() const; + + // Returns the linear index of the given index within this literal's + // element_type repeated field. 
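The templated Convert() above is assumed here to take a source and a destination native type as its two template parameters (the parameter list is elided in this rendering); a sketch:

```c++
// Sketch only: element-wise conversion from int32 to float.
auto s32 = Literal::CreateR1<int32>({1, 2, 3});
std::unique_ptr<Literal> f32 = s32->Convert<int32, float>();
CHECK_EQ(f32->Get<float>({2}), 3.0f);
```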
+ int64 LinearIndex(tensorflow::gtl::ArraySlice multi_index) const; // Gets or sets an element in the literal at the given index. The index is // CHECKed against the dimension sizes. template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value); + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template - static NativeT GetFirstElement(const Literal& literal); + NativeT GetFirstElement() const; // As Get(), but determines the correct type and converts the value // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + string GetAsString(tensorflow::gtl::ArraySlice multi_index) const; // Returns an identity matrix (rank 2) with the given row and column count. template @@ -243,10 +468,530 @@ class LiteralUtil { // Validates that the data payload of the literal matches the literal shape; // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal); + tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. - static string ToString(const Literal& literal); + string ToString() const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Templated methods which populate the given repeated field in this literal + // with the given value(s). The Shape field of this literal is set + // to match the array dimensions and type. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + template + void PopulateR0(NativeT values); + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateR2WithLayout( + std::initializer_list> values, + const Layout& layout); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + void PopulateR4FromArray4D(const Array4D& values); + template + void PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + + // Populates literal values by calling the generator function for every cell + // in this literal object. 
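A sketch of the generator-based Populate() declared just below, assuming it returns a Status and hands each cell its multidimensional index:

```c++
// Sketch only: filling a 2x2 F32 literal from its indices.
auto lit = Literal::CreateFromDimensions(F32, {2, 2});
TF_CHECK_OK(lit->Populate<float>(
    [](tensorflow::gtl::ArraySlice<int64> indexes) {
      return indexes[0] * 10.0f + indexes[1];  // hypothetical values
    }));
```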
+ template <typename NativeT> + Status Populate( + const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>& + generator); + + // Creates a Literal of the given dimensions with all elements set to the + // given value. + template <typename NativeT> + void PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice<int64> dimensions); + + // Returns a pointer to the underlying vector corresponding to the Literal's + // shape. + const void* InternalData() const; + void* MutableInternalData(); + + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type. Values in the vector are set + // to zero. num_elements must equal the number of elements in the literal's + // shape. + void Reserve(int64 num_elements); + + // Allocates space in the underlying vector of this literal sufficient to hold + // num_elements of this literal's primitive type and sets each element in this + // literal to the given value. num_elements must equal the number of elements + // in this literal's shape. + template <typename NativeT> + void Resize(int64 num_elements, NativeT value); + + // Returns true if this literal has the same shape and value as the given + // literal. Layout is not considered in the comparison. + bool Equal(const Literal& literal2) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. + bool IsAll(int8 value) const; + + // Like IsAll(int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. + bool IsAllFloat(float value) const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array. + bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const; + + private: + // Returns an ArraySlice view of the array for this literal for the given + // NativeT (e.g., float). These functions map native type to XLA PrimitiveType + // via template specialization. The unspecialized form below aborts to handle + // the error case where the given native type does not map to an XLA primitive + // type. + template <typename NativeT> + tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const { + static_assert(!std::is_same<NativeT, NativeT>::value, + "Cannot map native type to primitive type."); + } + + // Copy from a LiteralProto instance. + void CopyFromProto(const LiteralProto& literal_proto); + + // Internal template helper for the Copy() API, matching its arguments one by + // one. + template <typename NativeT> + Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. 
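Before the private implementation details that follow, a sketch of the value predicates documented above (hypothetical shapes and values):

```c++
// Sketch only: PopulateWithValue() plus the IsAll()/IsZero() predicates.
Literal lit;
lit.PopulateWithValue<float>(0.0f, {2, 3});
CHECK(lit.IsAll(0));          // 0 is representable exactly as a float
CHECK(lit.IsAllFloat(0.0f));
CHECK(lit.IsZero({1, 2}));    // element at index (1, 2)
```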
+ struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; + + Shape shape_; + BoolVector preds_; + std::vector u8s_; + std::vector s32s_; + std::vector s64s_; + std::vector u32s_; + std::vector u64s_; + std::vector f16s_; + std::vector f32s_; + std::vector f64s_; + std::vector tuple_literals_; +}; + +// Utility class for dealing with XLA literal values. Most methods are +// templated by native (host) type which corresponds to a unique XLA +// PrimitiveType. See ComputationBuilder for details. Not all primitive types +// defined in xla_data.proto have a corresponding native type or even have a +// storage location in the Literal proto yet (for example, primitive type F16). +// +// TODO(dnovillo) - All functions in this class simply redirect to the +// corresponding function in class Literal. Remove this class after converting +// all user code to use Literal directly. +class LiteralUtil { + public: + // Creates new literal of a given rank. To minimize ambiguity (for users and + // the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: + // + // CreateR1({1.0, 42.0}); + // CreateR2({{1, 2}, {3, 4}}); + // + // The variants not ending with WithLayout use the default XLA layout for the + // literal's linear representation in memory. + template + static std::unique_ptr CreateR0(NativeT value) { + return Literal::CreateR0(value); + } + + template + static std::unique_ptr CreateR1( + tensorflow::gtl::ArraySlice values) { + return Literal::CreateR1(values); + } + + static std::unique_ptr CreateR1( + const tensorflow::core::Bitmap& values) { + return Literal::CreateR1(values); + } + + template + static std::unique_ptr CreateR2( + std::initializer_list> values) { + return Literal::CreateR2(values); + } + + template + static std::unique_ptr CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout) { + return Literal::CreateR2WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3( + std::initializer_list< + std::initializer_list>> + values) { + return Literal::CreateR3(values); + } + + template + static std::unique_ptr CreateR3WithLayout( + std::initializer_list< + std::initializer_list>> + values, + const Layout& layout) { + return Literal::CreateR3WithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4( + std::initializer_list>>> + values) { + return Literal::CreateR4(values); + } + + template + static std::unique_ptr CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout) { + return Literal::CreateR4WithLayout(values, layout); + } + + // Creates a new Literal object with the shape specified as parameter. 
+ + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape) { + return Literal::CreateFromShape(shape); + } + + // Creates a new Literal object with its values having the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions) { + return Literal::CreateFromDimensions(primitive_type, dimensions); + } + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit within the source literal dimensions, and + // dest_base+copy_size must fit within the destination literal dimensions. + static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); + } + + // Creates a new value that has the same value as literal, but conforms + // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major + // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension + // layout and the value in the cell at any given logical index (i0, i1) will + // be the same. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + static std::unique_ptr Relayout(const Literal& literal, + const Layout& new_layout) { + return literal.Relayout(new_layout); + } + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. The implementation + // currently only supports monotonic dim0-major layouts. + static StatusOr> Reshape( + const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { + return input.Reshape(shape); + } + + // Creates a new literal by reordering the dimensions of the original literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + static std::unique_ptr Transpose( + const Literal& literal, tensorflow::gtl::ArraySlice permutation) { + return literal.Transpose(permutation); + }
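To make the Transpose contract concrete, a hypothetical sketch of the shape arithmetic described in the comment above (it assumes ShapeUtil::Equal and ShapeUtil::MakeShape from shape_util.h; the function name is illustrative):

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch only: permutation {2, 0, 1} applied to f32[3x8x4] yields f32[4x3x8],
// since new_order[i] = old_order[permutation[i]].
void TransposeSketch() {
  auto cube = xla::LiteralUtil::CreateFromDimensions(xla::F32, {3, 8, 4});
  auto transposed = xla::LiteralUtil::Transpose(*cube, {2, 0, 1});
  CHECK(xla::ShapeUtil::Equal(transposed->shape(),
                              xla::ShapeUtil::MakeShape(xla::F32, {4, 3, 8})));
}
```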
+ + // Creates a sub-array from the given literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + static std::unique_ptr Slice( + const Literal& literal, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) { + return literal.Slice(start_indices, limit_indices); + } + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input + // literal replicated four times. + template + static std::unique_ptr Replicate(const Literal& input, int64 times) { + return input.Replicate(times); + } + + // Creates a literal by converting each element in an original literal to a + // new type. + template + static std::unique_ptr Convert(const Literal& literal) { + return literal.Convert(); + } + + // Converts a literal to another primitive type, but only if the literal + // type is convertible to the destination type. + static StatusOr> ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type); + + // Creates a literal value zero of the given primitive type. + static Literal Zero(PrimitiveType primitive_type) { + return Literal::Zero(primitive_type); + } + + // Creates a literal value one of the given primitive type. + static Literal One(PrimitiveType primitive_type) { + return Literal::One(primitive_type); + } + + // Creates a literal value containing the minimum value of the given + // primitive type. For floating-point types, returns -inf. + static Literal MinValue(PrimitiveType primitive_type) { + return Literal::MinValue(primitive_type); + } + + // Creates a literal value containing the maximum value of the given + // primitive type. For floating-point types, returns inf. + static Literal MaxValue(PrimitiveType primitive_type) { + return Literal::MaxValue(primitive_type); + } + + // Creates a literal of the given shape where each element is `value`. + template + static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + tensorflow::gtl::ArraySlice dimensions, NativeT value) { + return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); + } + + // Creates a new literal from an array. The variants not ending with + // WithLayout use the default XLA layout for the literal's linear + // representation in memory. + template + static std::unique_ptr CreateR2FromArray2D( + const Array2D& values) { + return Literal::CreateR2FromArray2D(values); + } + + template + static std::unique_ptr CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return Literal::CreateR2FromArray2DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR3FromArray3D( + const Array3D& values) { + return Literal::CreateR3FromArray3D(values); + } + + template + static std::unique_ptr CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return Literal::CreateR3FromArray3DWithLayout(values, layout); + } + + template + static std::unique_ptr CreateR4FromArray4D( + const Array4D& values) { + return Literal::CreateR4FromArray4D(values); + } + + template + static std::unique_ptr CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return Literal::CreateR4FromArray4DWithLayout(values, layout); + } + + // Creates a new vector of U8s literal value from a string. + static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { + return Literal::CreateR1U8(value); + } + + // Creates a linspace-populated literal with the given number of rows and + // columns.
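A hypothetical sketch of Slice and Replicate from above; the linspace factory that the preceding comment introduces continues right after this aside. The stripped <NativeT> arguments are assumed:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only; assumes the stripped <NativeT> template arguments.
void SliceReplicateSketch() {
  auto vec = xla::LiteralUtil::CreateR1<float>({3.0f, 4.0f, 5.0f, 6.0f});
  // Keeps the index range [1, 3) of the single dimension: f32[2] {4, 5}.
  auto middle = xla::LiteralUtil::Slice(*vec, {1}, {3});
  // Prepends a dimension of bound 2: f32[2x2], each row equal to {4, 5}.
  auto stacked = xla::LiteralUtil::Replicate<float>(*middle, 2);
}
```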
+ static std::unique_ptr CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { + return Literal::CreateR2F32Linspace(from, to, rows, cols); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z dimension given by "projection". + template + static std::unique_ptr CreateR3Projected( + std::initializer_list> values, + int64 projection) { + return Literal::CreateR3Projected(values, projection); + } + + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z) { + return Literal::CreateR4Projected(values, projection_p, projection_z); + } + + // Clones literal into an owned unique_ptr version. + static std::unique_ptr CloneToUnique(const Literal& literal) { + return literal.CloneToUnique(); + } + + // Returns the linear index of the given index within the literal's + // element_type repeated field. + static int64 LinearIndex(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.LinearIndex(multi_index); + } + + // Gets or sets an element in the literal at the given index. The index is + // CHECKed against the dimension sizes. + template + static NativeT Get(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.Get(multi_index); + } + + template + static void Set(Literal* literal, + tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + literal->Set(multi_index, value); + } + + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( + Literal* literal) { + return literal->GetMutableArraySlice(); + } + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + static NativeT GetFirstElement(const Literal& literal) { + return literal.GetFirstElement(); + } + + // As Get(), but determines the correct type and converts the value + // into text. + static string GetAsString(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index) { + return literal.GetAsString(multi_index); + } + + // Returns an identity matrix (rank 2) with the given row and column count. + template + static std::unique_ptr MakeIdentityR2(int64 size) { + return Literal::MakeIdentityR2(size); + } + + // Returns a tuple literal composed of given literals. + static std::unique_ptr MakeTuple( + tensorflow::gtl::ArraySlice elements) { + return Literal::MakeTuple(elements); + } + + // Validates that the data payload of the literal matches the literal shape; + // if it does not, an appropriate status is returned. + static tensorflow::Status ValidateLiteral(const Literal& literal) { + return literal.ValidateLiteral(); + } + + // Returns a string representation of the literal value. + static string ToString(const Literal& literal) { return literal.ToString(); } // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -257,15 +1002,19 @@ class LiteralUtil { // like representation in a protobuf). 
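Before the per-cell iteration hooks that the preceding comment introduces, a hypothetical sketch of the element accessors above (Get, Set, GetAsString, MakeIdentityR2); it assumes the stripped <NativeT> arguments and TensorFlow's CHECK/LOG macros:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only; assumes the stripped <NativeT> template arguments.
void ElementAccessSketch() {
  auto m = xla::LiteralUtil::MakeIdentityR2<float>(2);  // f32[2,2] identity.
  CHECK_EQ(xla::LiteralUtil::Get<float>(*m, {0, 0}), 1.0f);
  xla::LiteralUtil::Set<float>(m.get(), {0, 1}, 7.0f);  // Indices are CHECKed.
  // GetAsString determines the element type and renders the cell as text.
  LOG(INFO) << xla::LiteralUtil::GetAsString(*m, {0, 1});
}
```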
static void EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell); + const std::function indices, + const string& value)>& per_cell) { + literal.EachCellAsString(per_cell); + } + template static void EachCell( const Literal& literal, std::function indices, NativeT value)> - per_cell); + per_cell) { + literal.EachCell(per_cell); + } // Templated methods which populate the given repeated field in the Literal // proto with the given value(s). The Shape field of the Literal proto is set @@ -279,70 +1028,125 @@ class LiteralUtil { // PopulateR2({{1, 2}, {3, 4}}, literal); // template - static void PopulateR0(NativeT values, Literal* literal); + static void PopulateR0(NativeT values, Literal* literal) { + literal->PopulateR0(values); + } + template static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal); + Literal* literal) { + literal->PopulateR1(values); + } + static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal); + Literal* literal) { + literal->PopulateR1(values); + } + template static void PopulateR2( std::initializer_list> values, - Literal* literal); + Literal* literal) { + literal->PopulateR2(values); + } + template static void PopulateR2WithLayout( std::initializer_list> values, - const Layout& layout, Literal* literal); + const Layout& layout, Literal* literal) { + literal->PopulateR2WithLayout(values, layout); + } + template static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal); + Literal* literal) { + literal->PopulateR2FromArray2D(values); + } + template static void PopulateR2FromArray2DWithLayout(const Array2D& values, const Layout& layout, - Literal* literal); + Literal* literal) { + literal->PopulateR2FromArray2DWithLayout(values, layout); + } + template static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal); + Literal* literal) { + literal->PopulateR3FromArray3D(values); + } + template static void PopulateR3FromArray3DWithLayout(const Array3D& values, const Layout& layout, - Literal* literal); + Literal* literal) { + literal->PopulateR3FromArray3DWithLayout(values, layout); + } + template static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal); + Literal* literal) { + literal->PopulateR4FromArray4D(values); + } + template static void PopulateR4FromArray4DWithLayout(const Array4D& values, const Layout& layout, - Literal* literal); + Literal* literal) { + literal->PopulateR4FromArray4DWithLayout(values, layout); + } + + // Populates literal values by calling the generator function for every cell + // in the literal object. + template + static Status Populate( + Literal* literal, + const std::function indexes)>& + generator) { + return literal->Populate(generator); + } // Creates a Literal of the given dimensions with all elements set to the // given value. template static void PopulateWithValue(NativeT value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); + Literal* literal) { + return literal->PopulateWithValue(value, dimensions); + } - // Returns a pointer to the underlying buffer in the protobuf containing the - // array data. Use with care. - static const void* InternalData(const Literal& literal); - static void* MutableInternalData(Literal* literal); + // Returns a pointer to the underlying vector containing the array data. Use + // with care. 
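The Populate() wrapper above drives a caller-supplied generator over every cell. A minimal sketch, modeled on the Populate test later in this change (names and the stripped template arguments are assumptions); the raw-data accessors that the preceding comment introduces follow:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only, modeled on the Populate test below in this patch.
void PopulateSketch() {
  auto literal = xla::Literal::CreateFromShape(
      xla::ShapeUtil::MakeShape(xla::U32, {2, 3}));
  TF_CHECK_OK(xla::LiteralUtil::Populate<xla::uint32>(
      literal.get(),
      [](tensorflow::gtl::ArraySlice<xla::int64> indexes) -> xla::uint32 {
        // Any per-cell rule works; here the value encodes its own index.
        return static_cast<xla::uint32>(indexes[0] * 10 + indexes[1]);
      }));
}
```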
+ static const void* InternalData(const Literal& literal) { + return literal.InternalData(); + } - // Allocates space in the repeated_field of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the buffer are set + static void* MutableInternalData(Literal* literal) { + return literal->MutableInternalData(); + } + + // Allocates space in the underlying vector of the literal sufficient to hold + // num_elements of the literal's primitive type. Values in the vector are set // to zero. num_elements must equal the number of elements in the literal's // shape. - static void Reserve(int64 num_elements, Literal* literal); + static void Reserve(int64 num_elements, Literal* literal) { + literal->Reserve(num_elements); + } - // Allocates space in the repeated_field of the literal sufficient to hold + // Allocates space in the underlying vector of the literal sufficient to hold // num_elements of the literal's primitive type and sets each element in the // literal to the given value. num_elements must equal the number of elements // in the literal's shape. template - static void Resize(int64 num_elements, NativeT value, Literal* literal); + static void Resize(int64 num_elements, NativeT value, Literal* literal) { + literal->Resize(num_elements, value); + } // Returns true if the two given literals have the same shape and // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2); + static bool Equal(const Literal& literal1, const Literal& literal2) { + return literal1.Equal(literal2); + } // Returns whether every element in the given literal is equal to value. // @@ -353,7 +1157,9 @@ class LiteralUtil { // If value doesn't fit in literal's type, returns false. Values of 1/0 are // considered equal to true/false; other values are not considered equal to // true. - static bool IsAll(const Literal& literal, int8 value); + static bool IsAll(const Literal& literal, int8 value) { + return literal.IsAll(value); + } // Like IsAll(const Literal&, int8), except we check whether the literal is // equal to a particular floating-point number. @@ -364,137 +1170,149 @@ class LiteralUtil { // admonishments about floating-point equality checks apply. We expect you to // use this to check for values that can be expressed precisely as a float, // e.g. -0.5. - static bool IsAllFloat(const Literal& literal, float value); + static bool IsAllFloat(const Literal& literal, float value) { + return literal.IsAllFloat(value); + } // Returns whether the literal is zero at the specified index. The literal // must be an array. static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices); - - private: - // Returns an ArraySlice view of the array for the given literal for the - // given NativeT (e.g., float). These - // functions map native type to XLA PrimitiveType via template - // specialization. The unspecialized forms below aborts to handle the error - // case where the given native type does not map to an XLA primitive type.
- template - static tensorflow::gtl::ArraySlice GetArraySlice( - const Literal& literal) { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); + tensorflow::gtl::ArraySlice indices) { + return literal.IsZero(indices); } - template - static tensorflow::protobuf::RepeatedField* GetMutableRepeatedField( - Literal* literal) { - // Make the expression depend on the template parameter NativeT so - // that this compile-time error only apperas if this function is - // instantiated with some concrete type that is not specialized - // below. - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; // Declarations of template specializations for GetArraySlice and -// GetMutableRepeatedField. The specializations map native type to XLA primitive +// GetMutableArraySlice. The specializations map native type to XLA primitive // type. template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +inline tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + DCHECK(shape().element_type() == F32); + return f32s(); +} template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> -/* static */ 
tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + +template <> +void Literal::Resize(int64 num_elements, bool value); + +template <> +void Literal::Resize(int64 num_elements, int8 value); + +template <> +void Literal::Resize(int64 num_elements, uint8 value); + +template <> +void Literal::Resize(int64 num_elements, int32 value); + +template <> +void Literal::Resize(int64 num_elements, uint32 value); + +template <> +void Literal::Resize(int64 num_elements, int64 value); + +template <> +void Literal::Resize(int64 num_elements, uint64 value); + +template <> +void Literal::Resize(int64 num_elements, float value); + +template <> +void Literal::Resize(int64 num_elements, double value); + +template <> +void Literal::Resize(int64 num_elements, half value); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { +/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); - PopulateR0(value, literal.get()); + literal->PopulateR0(value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ std::unique_ptr Literal::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique(); - PopulateR1(values, literal.get()); + literal->PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ std::unique_ptr Literal::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR2WithLayout(values, layout, literal.get()); + literal->PopulateR2WithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ std::unique_ptr Literal::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ std::unique_ptr Literal::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -519,14 +1337,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ std::unique_ptr Literal::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ std::unique_ptr Literal::CreateR4WithLayout( std::initializer_list>>> values, @@ -557,7 +1375,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ std::unique_ptr Literal::CreateR4( std::initializer_list>>> values) { @@ -565,38 +1383,37 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { auto literal = MakeUnique(); - 
PopulateR2FromArray2DWithLayout(values, layout, literal.get()); + literal->PopulateR2FromArray2DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { return CreateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR3FromArray3DWithLayout(values, layout, literal.get()); + literal->PopulateR3FromArray3DWithLayout(values, layout); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { return CreateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ std::unique_ptr Literal::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -621,7 +1438,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ std::unique_ptr Literal::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -649,91 +1466,92 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { return CreateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { auto literal = MakeUnique(); - PopulateR4FromArray4DWithLayout(values, layout, literal.get()); + literal->PopulateR4FromArray4DWithLayout(values, layout); return literal; } template -/* static */ NativeT LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - int64 linear_index = LinearIndex(literal, multi_index); - return GetArraySlice(literal).at(linear_index); +NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index) const { + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice().at(linear_index); } template -/* static */ NativeT LiteralUtil::GetFirstElement(const Literal& literal) { - return GetArraySlice(literal).at(0); +NativeT Literal::GetFirstElement() const { + return GetArraySlice().at(0); } template <> -/* static */ inline uint8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == U8); - int64 linear_index = LinearIndex(literal, multi_index); - return literal.u8s()[linear_index]; +inline uint8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == U8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; } template <> -/* static */ inline int8 LiteralUtil::Get( - const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { - CHECK(literal.shape().element_type() == S8); - int64 linear_index = LinearIndex(literal, multi_index); - return 
literal.u8s()[linear_index]; +inline int8 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == S8); + int64 linear_index = LinearIndex(multi_index); + return u8s()[linear_index]; +} + +template <> +inline half Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == F16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; } template -/* static */ void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - int64 linear_index = LinearIndex(*literal, multi_index); - GetMutableRepeatedField(literal)->Set(linear_index, value); +void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + int64 linear_index = LinearIndex(multi_index); + GetMutableArraySlice().at(linear_index) = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint8 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u8s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + uint8 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u8s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int8 value) { - return Set(literal, multi_index, value); +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int8 value) { + return Set(multi_index, value); } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - int64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_s64s())[linear_index] = value; +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_s64s())[linear_index] = value; } template <> -/* static */ inline void LiteralUtil::Set( - Literal* literal, tensorflow::gtl::ArraySlice multi_index, - uint64 value) { - int64 linear_index = LinearIndex(*literal, multi_index); - (*literal->mutable_u64s())[linear_index] = value; +/* static */ inline void Literal::Set( + tensorflow::gtl::ArraySlice multi_index, uint64 value) { + int64 linear_index = LinearIndex(multi_index); + (*mutable_u64s())[linear_index] = value; } // Returns an identity matrix (rank 2) with the given row and column count. 
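A small sketch of the F16 path added above, reading an element through the new Get<half> specialization; it assumes xla::half (the Eigen::half alias from xla/types.h) and the stripped template arguments. The identity-matrix helper that the preceding comment introduces follows:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only; exercises the new F16 Get<half> specialization above.
void HalfAccessSketch() {
  xla::half h1(1.0f);
  xla::half h2(2.0f);
  auto m = xla::Literal::CreateR2<xla::half>({{h1, h2}, {h2, h1}});
  CHECK(m->Get<xla::half>({0, 1}) == h2);  // Reads via GetArraySlice<half>().
}
```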
template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -742,88 +1560,51 @@ template } template -/* static */ void LiteralUtil::EachCell( - const Literal& literal, +void Literal::EachCell( std::function indices, NativeT value)> - per_cell) { - if (ShapeUtil::HasZeroElements(literal.shape())) { + per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { return; } - std::vector indices(ShapeUtil::Rank(literal.shape()), 0); + std::vector indices(ShapeUtil::Rank(shape()), 0); do { - per_cell(indices, Get(literal, indices)); - } while (IndexUtil::BumpIndices(literal.shape(), &indices)); + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } template -/* static */ void LiteralUtil::PopulateR0(NativeT value, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( +inline void Literal::PopulateR0(NativeT value) { + *mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {}); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_s64s()->Add(value); + Resize(1, value); } template -/* static */ void LiteralUtil::PopulateR1( - tensorflow::gtl::ArraySlice values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { + *mutable_shape() = ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())}); - Reserve(values.size(), literal); + Reserve(values.size()); for (int64 i = 0; i < values.size(); ++i) { - Set(literal, {i}, values[i]); + Set({i}, values[i]); } } -/* static */ inline void LiteralUtil::PopulateR1( - const tensorflow::core::Bitmap& values, Literal* literal) { - *literal->mutable_shape() = +inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + *mutable_shape() = ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); - Reserve(values.bits(), literal); - for (int64 i = 0; i < values.bits(); ++i) { - Set(literal, {i}, values.get(i)); + Reserve(values.bits()); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); } } template -/* static */ void LiteralUtil::PopulateR2WithLayout( +void Literal::PopulateR2WithLayout( std::initializer_list> values, - const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( + const Layout& layout) { + 
*mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -831,17 +1612,17 @@ template const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); const int64 num_elements = dim1_size * dim0_size; - Reserve(num_elements, literal); + Reserve(num_elements); int64 dim0 = 0; for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - Set(literal, {dim0, dim1}, value); + Set({dim0, dim1}, value); ++dim1; } CHECK_EQ(dim1_size, dim1); @@ -850,84 +1631,79 @@ template } template -/* static */ void LiteralUtil::PopulateR2( - std::initializer_list> values, - Literal* literal) { - PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), literal); +void Literal::PopulateR2( + std::initializer_list> values) { + PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ void LiteralUtil::PopulateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); const int64 dim1_size = values.width(); const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, literal->shape().dimensions(0)); - CHECK_EQ(dim1_size, literal->shape().dimensions(1)); - Reserve(dim1_size * dim0_size, literal); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + Reserve(dim1_size * dim0_size); for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set(literal, {dim0, dim1}, values(dim0, dim1)); + Set({dim0, dim1}, values(dim0, dim1)); } } } template -/* static */ void LiteralUtil::PopulateR2FromArray2D( - const Array2D& values, Literal* literal) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), - literal); +void Literal::PopulateR2FromArray2D(const Array2D& values) { + PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } + template -/* static */ void LiteralUtil::PopulateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.n1(), values.n2(), values.n3()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + Reserve(values.n1() * values.n2() * values.n3()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < values.n3(); 
++dim2) { - Set(literal, {dim0, dim1, dim2}, values(dim0, dim1, dim2)); + Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); } } } } template -/* static */ void LiteralUtil::PopulateR3FromArray3D( - const Array3D& values, Literal* literal) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3(), - literal); +void Literal::PopulateR3FromArray3D(const Array3D& values) { + PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ void LiteralUtil::PopulateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout, Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout( +void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { + *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {values.planes(), values.depth(), values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); - CHECK_EQ(values.n1(), literal->shape().dimensions(0)); - CHECK_EQ(values.n2(), literal->shape().dimensions(1)); - CHECK_EQ(values.n3(), literal->shape().dimensions(2)); - CHECK_EQ(values.n4(), literal->shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4(), literal); + CHECK_EQ(values.n1(), shape().dimensions(0)); + CHECK_EQ(values.n2(), shape().dimensions(1)); + CHECK_EQ(values.n3(), shape().dimensions(2)); + CHECK_EQ(values.n4(), shape().dimensions(3)); + Reserve(values.n1() * values.n2() * values.n3() * values.n4()); for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set(literal, {dim0, dim1, dim2, dim3}, - values(dim0, dim1, dim2, dim3)); + Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); } } } @@ -935,106 +1711,124 @@ template } template -/* static */ void LiteralUtil::PopulateR4FromArray4D( - const Array4D& values, Literal* literal) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4(), - literal); +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } template -/* static */ void LiteralUtil::PopulateWithValue( - NativeT value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); +Status Literal::Populate( + const std::function indexes)>& + generator) { + const Shape& this_shape = shape(); + int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(this_shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice data = + GetMutableArraySlice(); + if (rank > 0) { + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); + DimensionVector minor_scan_indexes(rank, 0); + int64 minor_dimension_size = + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); + + auto init_function = [&](const std::vector& indexes) { + int64 index = LinearIndex(indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + 
minor_scan_indexes[stride_config.minor_dimension] = i; + data.at(index + i) = generator(minor_scan_indexes); + } + return true; + }; + ShapeUtil::ForEachIndex(this_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + init_function); + } else { + // For scalars. + data.at(0) = generator({}); } + return Status::OK(); } -template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); - -template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); +template +void Literal::PopulateWithValue(NativeT value, + tensorflow::gtl::ArraySlice dimensions) { + *mutable_shape() = ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), dimensions); + Resize(ShapeUtil::ElementsIn(shape()), value); +} template -/* static */ std::unique_ptr LiteralUtil::Convert( - const Literal& literal) { +std::unique_ptr Literal::Convert() const { + const Shape& this_shape = shape(); auto result_literal = MakeUnique(); - Shape result_shape = literal.shape(); - result_shape.set_element_type( + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = this_shape; + result_shape->set_element_type( primitive_util::NativeToPrimitiveType()); - *result_literal->mutable_shape() = result_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(result_shape), - result_literal.get()); - LiteralUtil::EachCell( - literal, - [&](tensorflow::gtl::ArraySlice indices, NativeSrcT value) { - LiteralUtil::Set(result_literal.get(), indices, - static_cast(value)); - }); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + tensorflow::gtl::ArraySlice src_data = + GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(this_shape); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } return result_literal; } -template -/* static */ void LiteralUtil::Resize(int64 num_elements, NativeT value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); -} - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); - template /* static */ std::unique_ptr -LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( +Literal::CreateFullWithMonotonicDim0MajorLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( primitive_util::NativeToPrimitiveType(), dimensions); auto literal = MakeUnique(); - *literal->mutable_shape() = shape; - Reserve(ShapeUtil::ElementsIn(shape), literal.get()); + *literal->mutable_shape() = this_shape; + literal->Reserve(ShapeUtil::ElementsIn(this_shape)); std::vector index(dimensions.size(), 0); do { - Set(literal.get(), index, value); - } while (IndexUtil::BumpIndices(shape, &index)); + literal->Set(index, value); + } while (IndexUtil::BumpIndices(this_shape, &index)); return literal; } template -/* static */ std::unique_ptr LiteralUtil::Replicate( - const Literal& input, int64 times) { - std::vector bounds = {times}; - 
bounds.insert(bounds.end(), input.shape().dimensions().begin(), - input.shape().dimensions().end()); +std::unique_ptr Literal::Replicate(int64 times) const { + DimensionVector bounds = {times}; + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { + bounds.push_back(bound); + } auto literal = MakeUnique(); *literal->mutable_shape() = - ShapeUtil::MakeShape(input.shape().element_type(), bounds); - Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); - for (int64 index = 0; index < ShapeUtil::ElementsIn(input.shape()); ++index) { - const std::vector element_indices = - IndexUtil::LinearIndexToMultidimensionalIndex(input.shape(), index); - const auto element = Get(input, element_indices); - for (int64 sample = 0; sample < times; ++sample) { - std::vector output_indices = {sample}; - output_indices.insert(output_indices.end(), element_indices.begin(), - element_indices.end()); - Set(literal.get(), output_indices, element); + ShapeUtil::MakeShape(shape().element_type(), bounds); + int64 elements = ShapeUtil::ElementsIn(literal->shape()); + if (elements == 0) { + return literal; + } + literal->Reserve(elements); + + DimensionVector output_indices(bounds.size(), 0); + tensorflow::gtl::ArraySlice input_indices = output_indices; + input_indices.remove_prefix(1); + + bool done = false; + while (!done) { + const auto element = Get(input_indices); + literal->Set(output_indices, element); + + done = true; + for (int n = 0; n < output_indices.size(); ++n) { + ++output_indices[n]; + if (output_indices[n] < bounds[n]) { + done = false; + break; + } + output_indices[n] = 0; } } return literal; diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e4adb5df56a..8d4a75d7aff 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -21,14 +21,17 @@ limitations under the License. 
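A hypothetical sketch of the observable effect of the Replicate() implementation above, which walks the output indices odometer-style; the names and stripped template arguments are assumptions, and the test changes below exercise the R2 case:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only: Replicate prepends a dimension and repeats the input.
void ReplicateSketch() {
  auto input = xla::LiteralUtil::CreateR1<xla::uint32>({1, 2, 3});
  auto output = xla::LiteralUtil::Replicate<xla::uint32>(*input, 2);
  auto expected =
      xla::LiteralUtil::CreateR2<xla::uint32>({{1, 2, 3}, {1, 2, 3}});
  CHECK(xla::LiteralUtil::Equal(*output, *expected));
}
```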
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { +using ::testing::ElementsAre; + class LiteralUtilTest : public ::testing::Test { protected: LiteralUtilTest() { @@ -101,6 +104,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -159,9 +165,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_MATCH(testing::PBToVec( - literal->shape().dimensions()), - testing::VectorMatcher({2, 3, 2})); + EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); string result = LiteralUtil::ToString(*literal); const string expected = R"(f32[2,3,2] { { { 1, 2 }, @@ -182,9 +186,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_MATCH( - testing::PBToVec(literal->shape().dimensions()), - testing::VectorMatcher({1, 2, 3, 2})); + EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = LiteralUtil::ToString(*literal); const string expected = R"(f32[1,2,3,2] { { // i0=0 @@ -204,10 +206,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_MATCH( - testing::PBToVec( - literal_r4_2x2x3x3_dim0major_->shape().dimensions()), - testing::VectorMatcher({2, 2, 3, 3})); + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + ElementsAre(2, 2, 3, 3)); string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); const string expected = R"(f32[2,2,3,3] { { // i0=0 @@ -375,6 +375,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE( LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + half h8(8.0f); + half h9(9.0f); + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::IsAll( *LiteralUtil::CreateR2( @@ -471,6 +480,26 @@ TEST_F(LiteralUtilTest, ReshapeR4) { EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); } +TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { + // clang-format off + // F32[1x3x2x4] + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0minor_); + // F32[1x3x4x2] + auto expected = LiteralUtil::CreateR3WithLayout({ + {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, + {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, + {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, + }, layout_r3_dim0major_); + // clang-format on + auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + + 
EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); +} + TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); @@ -516,27 +545,23 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) { auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->s32s_size(), 6); - EXPECT_MATCH(testing::PBToVec(mat_dim0minor->s32s()), - testing::VectorMatcher({1, 4, 2, 5, 3, 6})); + EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. auto relaid_mat_to_dim0major = LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); - EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0major->s32s()), - testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->s32s_size(), 6); - EXPECT_MATCH(testing::PBToVec(mat_dim0major->s32s()), - testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. auto relaid_mat_to_dim0minor = LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); - EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0minor->s32s()), - testing::VectorMatcher({1, 4, 2, 5, 3, 6})); + EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); } TEST_F(LiteralUtilTest, TestR3LinearLayout) { @@ -558,28 +583,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { EXPECT_EQ(lit_dim0minor->s32s_size(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_MATCH(testing::PBToVec(lit_dim0minor->s32s()), - testing::VectorMatcher(expected_dim0minor)); + EXPECT_THAT(lit_dim0minor->s32s(), + testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. auto relaid_lit_to_dim0major = LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0major->s32s()), - testing::VectorMatcher(expected_dim0major)); + EXPECT_THAT(relaid_lit_to_dim0major->s32s(), + testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->s32s_size(), 12); - EXPECT_MATCH(testing::PBToVec(lit_dim0major->s32s()), - testing::VectorMatcher(expected_dim0major)); + EXPECT_THAT(lit_dim0major->s32s(), + testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. 
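For reference while reading these linear-layout tests, a hypothetical sketch of how minor-to-major order fixes the element order reported by s32s(); it assumes LayoutUtil::MakeLayout and the stripped template arguments:

```c++
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Sketch only: {{1, 2, 3}, {4, 5, 6}} stored dim0-minor (minor_to_major
// {0, 1}) linearizes column by column as 1, 4, 2, 5, 3, 6.
void LayoutSketch() {
  auto mat = xla::LiteralUtil::CreateR2WithLayout<xla::int32>(
      {{1, 2, 3}, {4, 5, 6}}, xla::LayoutUtil::MakeLayout({0, 1}));
  CHECK_EQ(mat->s32s()[0], 1);
  CHECK_EQ(mat->s32s()[1], 4);
}
```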
auto relaid_lit_to_dim0minor = LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); - EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0minor->s32s()), - testing::VectorMatcher(expected_dim0minor)); + EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), + testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { @@ -645,5 +670,358 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { + Literal output; + half h(0.25f); + LiteralUtil::PopulateWithValue(h, {}, &output); + auto expected = LiteralUtil::CreateR0(h); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { + Literal output; + half h(0.5f); + LiteralUtil::PopulateWithValue(h, {3}, &output); + auto expected = LiteralUtil::CreateR1({h, h, h}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { + Literal output; + half h(2.0f); + LiteralUtil::PopulateWithValue(h, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, ReplicateR2U32) { + auto input = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto output = LiteralUtil::Replicate(*input, 3); + auto expected = LiteralUtil::CreateR3( + {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); + EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); +} + +TEST_F(LiteralUtilTest, Copy) { + const int64 dimensions[] = {17, 15, 34, 21}; + const int64 layouts[][4] = { + {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; + for (const auto& layout : layouts) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), dimensions, layout); + auto blank = LiteralUtil::CreateFromShape(shape); + auto source = LiteralUtil::CreateFromShape(shape); + const int64 zero_base[] = {0, 0, 0, 0}; + const int64 step[] = {1, 1, 1, 1}; + uint32 seqnr = 0; + auto init_proc = [&](const std::vector& indexes) { + LiteralUtil::Set(source.get(), indexes, ++seqnr); + return true; + }; + + ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + init_proc); + + const int64 src_base[] = {3, 1, 5, 7}; + const int64 dest_base[] = {6, 4, 12, 2}; + const int64 copy_size[] = {7, 8, 11, 9}; + + TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, + copy_size)); + std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); + std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); + bool matched = true; + auto check_proc = [&](const std::vector& indexes) { + std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); + std::transform(source_indexes.begin(), source_indexes.end(), src_base, + source_indexes.begin(), std::plus()); + std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); + std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, + blank_indexes.begin(), std::plus()); + auto bval = LiteralUtil::Get(*blank, blank_indexes); + matched = (bval != 0 && + bval == LiteralUtil::Get(*source, source_indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + check_proc); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, CopyScalars) { + auto zero = LiteralUtil::CreateR0(0); + auto nine = 
LiteralUtil::CreateR0(9); + TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); + EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); + + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); + EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); + TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {})); + EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); +} + +TEST_F(LiteralUtilTest, F16) { + // Verify that the internal data views are consistent and that they + // are in little endian format + // TODO - modify if we make the data format machine endianness dependent + auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + Literal* l1 = m1.get(); + const char* d1 = static_cast(LiteralUtil::InternalData(*l1)); + EXPECT_EQ(d1[0], 0); + EXPECT_EQ(d1[1], 0); + EXPECT_EQ(d1[2], 0); + EXPECT_EQ(d1[3], 0); + EXPECT_EQ(d1[4], 0); + EXPECT_EQ(d1[5], 0); + EXPECT_EQ(d1[6], 0); + EXPECT_EQ(d1[7], 0); + EXPECT_EQ(LiteralUtil::InternalData(*l1), + LiteralUtil::MutableInternalData(l1)); + + half h1(1.0f); + half h2(2.0f); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + Literal* l2 = m2.get(); + const char* d2 = static_cast(LiteralUtil::InternalData(*l2)); + EXPECT_EQ(d2[0], 0); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0); + EXPECT_EQ(d2[7], 0x3C); + EXPECT_EQ(LiteralUtil::InternalData(*l2), + LiteralUtil::MutableInternalData(l2)); +} + +TEST_F(LiteralUtilTest, Populate) { + struct PopulateData { + std::vector dimensions; + std::vector layout; + } populate_data[] = { + {{}, {}}, + {{0}, {0}}, + {{16}, {0}}, + {{2, 0}, {1, 0}}, + {{4, 16}, {1, 0}}, + {{21, 12}, {0, 1}}, + {{6, 11, 17}, {2, 0, 1}}, + {{6, 11, 5, 17}, {3, 2, 0, 1}}, + }; + for (const auto& data : populate_data) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), data.dimensions, + data.layout); + auto literal = LiteralUtil::CreateFromShape(shape); + auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + // Offsets from linear index just to keep R0 literals from being + // initialized with zero.
+ return LiteralUtil::LinearIndex(*literal, indexes) + 17; + }; + TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + bool matched = true; + auto check_function = [&](const std::vector& indexes) { + auto value = LiteralUtil::Get(*literal, indexes); + matched = matched && (value == generator(indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + check_function); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, ConvertR4) { + // clang-format off + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + auto expected = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + // clang-format on + auto converted = LiteralUtil::Convert(*original); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); +} + +TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { + // clang-format off + auto s8 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); + auto s32 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); + auto u32 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); + auto s64 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); + auto u64 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); + auto pred = LiteralUtil::CreateR4WithLayout({{ + {{true, false, true, false}, {false, true, false, true}}, + {{false, true, false, true}, {true, false, true, false}}, + {{true, false, true, false}, {false, true, false, true}}, + }}, layout_r4_dim0major_); + auto int32_pred = LiteralUtil::CreateR4WithLayout({{ + {{1, 0, 1, 0}, {0, 1, 0, 1}}, + {{0, 1, 0, 1}, {1, 0, 1, 0}}, + {{1, 0, 1, 0}, {0, 1, 0, 1}}, + }}, layout_r4_dim0major_); + auto f32 = LiteralUtil::CreateR4WithLayout({{ + {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, + {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, + {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, + }}, layout_r4_dim0major_); + auto f64 = LiteralUtil::CreateR4WithLayout({{ + {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, + {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, + {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, + }}, layout_r4_dim0major_); + // clang-format on + std::unique_ptr conv; + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *u32)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U64).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *u64)); + + conv = 
LiteralUtil::ConvertIfSrcTypeMatches(*s8, S64).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *s64)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, PRED).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *pred)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*pred, S32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *int32_pred)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*f32, S32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*f64, S32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + + conv = LiteralUtil::ConvertIfSrcTypeMatches(*s32, F32).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralUtil::Equal(*conv, *f32)); + + EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, TUPLE).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, F16).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, S16).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, U16).status().code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(LiteralUtilTest, CopyFromProto_Bool) { + LiteralProto p; + p.mutable_shape()->set_element_type(PRED); + for (int len = 0; len < 25; ++len) { + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(len); + p.clear_preds(); + for (int i = 0; i < len; ++i) { + p.add_preds((i % 2) == (len % 2)); + } + + Literal literal(p); + ASSERT_EQ(len, literal.preds_size()); + int i = 0; + for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) { + EXPECT_EQ((i % 2) == (len % 2), *it); + ++i; + } + } +} + +// Note that f16 is currently stored in a byte array in little endian byte order +TEST_F(LiteralUtilTest, ToProto_f16) { + half h1(1.0f); + half h2(2.0f); + + auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); + Literal* l = m.get(); + EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); + EXPECT_EQ(4, l->f16s().size()); + EXPECT_EQ(4, l->f16s_size()); + + LiteralProto p = l->ToProto(); + EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); + EXPECT_EQ(8, p.f16s().size()); + const char* d = p.f16s().data(); + EXPECT_EQ(d[0], 0); + EXPECT_EQ(d[1], 0x3C); + EXPECT_EQ(d[2], 0); + EXPECT_EQ(d[3], 0x40); + EXPECT_EQ(d[4], 0); + EXPECT_EQ(d[5], 0x40); + EXPECT_EQ(d[6], 0); + EXPECT_EQ(d[7], 0x3C); +} + +// Note that f16 is currently stored in a byte array in little endian byte order +TEST_F(LiteralUtilTest, CopyFromProto_f16) { + half h1(1.0f); + half h2(2.0f); + + const char half_vals[8] = { + 0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C + }; + LiteralProto p; + p.mutable_shape()->set_element_type(F16); + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(4); + p.clear_f16s(); + p.set_f16s(half_vals, 8); + + + Literal literal(p); + ASSERT_EQ(4, literal.f16s_size()); + ASSERT_EQ(h1, literal.f16s(0)); + ASSERT_EQ(h2, literal.f16s(1)); + ASSERT_EQ(h2, literal.f16s(2)); + ASSERT_EQ(h1, literal.f16s(3)); + + const std::vector& r = literal.f16s(); + ASSERT_EQ(4, r.size()); + ASSERT_EQ(h1, r[0]); + ASSERT_EQ(h2, r[1]); + ASSERT_EQ(h2, r[2]); + ASSERT_EQ(h1, r[3]); +} + + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index cd7c42f6e17..fed0e58e66a 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ 
b/tensorflow/compiler/xla/metric_table_report.cc
@@ -38,7 +38,8 @@ void MetricTableReport::SetEntryName(string entry_name) {
 
 void MetricTableReport::SetShowAllEntries() {
   max_entries_to_show_ = std::numeric_limits<int64>::max();
-  max_metric_proportion_to_show = 1.1;  // more than 100%
+  max_entries_per_category_to_show_ = std::numeric_limits<int64>::max();
+  max_metric_proportion_to_show_ = 1.1;  // more than 100%
 }
 
 void MetricTableReport::SetShowCategoryTable() { show_category_table_ = true; }
@@ -141,7 +142,7 @@ void MetricTableReport::AppendCategoryTable() {
   int64 categories_shown = 0;
   for (const auto& category : categories) {
     if (categories_shown >= max_entries_to_show_ ||
-        metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) {
+        metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) {
       break;
     }
     ++categories_shown;
@@ -149,22 +150,21 @@ void MetricTableReport::AppendCategoryTable() {
 
     // Show the category.
     string text = category.category_text;
-    if (text == "") {
+    if (text.empty()) {
       text = "[no category]";
     }
     tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ",
                                    entry_name_, ")");
     AppendTableRow(text, category.metric_sum, metric_sum);
 
-    // Show the top few entries in the category.
-    const int64 kMaxToShow = 5;
+    // Show the top entries in the category.
     const char* const kIndentPrefix = "  * ";
-    int64 entries_to_show =
-        std::min<int64>(kMaxToShow, category.entries.size());
-    if (category.entries.size() == kMaxToShow + 1) {
+    int64 entries_to_show = std::min<int64>(max_entries_per_category_to_show_,
+                                            category.entries.size());
+    if (category.entries.size() == entries_to_show + 1) {
       // May as well show the last entry on the line that would otherwise say
       // that there is a single entry not shown.
-      entries_to_show = category.entries.size();
+      ++entries_to_show;
     }
     for (int64 i = 0; i < entries_to_show; ++i) {
       AppendLine(kIndentPrefix, MetricPercent(category.entries[i]->metric), " ",
@@ -193,14 +193,14 @@ void MetricTableReport::AppendEntryTable() {
   int64 entries_shown = 0;
   for (const auto& entry : entries_) {
     if (entries_shown >= max_entries_to_show_ ||
-        metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) {
+        metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) {
       break;
     }
     ++entries_shown;
     metric_sum += entry.metric;
 
     string text = entry.text;
-    if (text == "") {
+    if (text.empty()) {
       text = "[no entry text]";
     }
     AppendTableRow(text, entry.metric, metric_sum);
@@ -220,7 +220,14 @@ void MetricTableReport::AppendTableRow(const string& text, const double metric,
   const int64 max_metric_string_size =
       MetricString(expected_metric_sum_).size();
   string metric_string = MetricString(metric);
-  string padding(max_metric_string_size - metric_string.size() + 1, ' ');
+
+  // Don't try to make a gigantic string and crash if expected_metric_sum_ is
+  // wrong somehow.
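+  // (Illustrative: metric_string.size() is unsigned, so if it ever exceeded
+  // max_metric_string_size the old unchecked subtraction would wrap around to
+  // a huge value and the string fill-constructor would attempt an enormous
+  // allocation. The clamp below avoids that.)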
+ int64 padding_len = 1; + if (max_metric_string_size >= metric_string.size()) { + padding_len += max_metric_string_size - metric_string.size(); + } + string padding(padding_len, ' '); AppendLine(padding, metric_string, " (", MetricPercent(metric), " Σ", MetricPercent(running_metric_sum), ") ", text); } diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index e967627bff4..818fb1d3fe0 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -103,6 +103,7 @@ class MetricTableReport { private: static constexpr double kDefaultMaxMetricProportionToShow = 0.99; static constexpr int64 kDefaultMaxEntriesToShow = 100; + static constexpr int64 kDefaultMaxEntriesPerCategoryToShow = 5; // Append all parameters to the report. template @@ -162,7 +163,8 @@ class MetricTableReport { // These members control how many categories and entries to show in tables. int64 max_entries_to_show_ = kDefaultMaxEntriesToShow; - double max_metric_proportion_to_show = kDefaultMaxMetricProportionToShow; + int64 max_entries_per_category_to_show_ = kDefaultMaxEntriesPerCategoryToShow; + double max_metric_proportion_to_show_ = kDefaultMaxMetricProportionToShow; // The report that is being created. string report_; diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 21766a2a0c8..d488830a6cd 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -60,8 +60,8 @@ StatusOr> PackedLiteralReader::Read( int64 elements = ShapeUtil::ElementsIn(shape); LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), result.get()); - tensorflow::protobuf::RepeatedField* field = result->mutable_f32s(); - char* data = tensorflow::bit_cast(field->mutable_data()); + std::vector* field = result->mutable_f32s(); + char* data = tensorflow::bit_cast(field->data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 563d978cf5d..45a9fe01278 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/port/BUILD b/tensorflow/compiler/xla/port/BUILD deleted file mode 100644 index 6fc5f1185c9..00000000000 --- a/tensorflow/compiler/xla/port/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -# Filegroup used to collect source files for dependency checking. 
-filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), - visibility = ["//tensorflow/compiler/xla:internal"], -) - -cc_library( - name = "initialize", - hdrs = ["initialize.h"], - visibility = [ - "//tensorflow/compiler/xla:__subpackages__", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/port/initialize.h b/tensorflow/compiler/xla/port/initialize.h deleted file mode 100644 index 13d9632f97c..00000000000 --- a/tensorflow/compiler/xla/port/initialize.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ -#define TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ - -#undef REGISTER_MODULE_INITIALIZER - -namespace xla { - -class Initializer { - public: - typedef void (*InitializerFunc)(); - explicit Initializer(InitializerFunc func) { func(); } -}; - -} // namespace xla - -#define REGISTER_INITIALIZER(type, name, body) \ - static void google_init_##type##_##name() { body; } \ - xla::Initializer google_initializer_##type##_##name( \ - google_init_##type##_##name) - -#define REGISTER_MODULE_INITIALIZER(name, body) \ - REGISTER_INITIALIZER(module, name, body) - -#endif // TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_ diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e3909ae8e97..e4e37177a2d 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return F16; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 78f0ee6f592..162a11c7d29 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -75,6 +75,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); bool IsFloatingPointType(PrimitiveType type); @@ -150,6 +152,10 @@ template <> struct PrimitiveTypeToNative { using type = double; }; +template <> +struct PrimitiveTypeToNative { + using type = half; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index adb2e99ad25..cdc4139cd69 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -14,7 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/protobuf_util.h" + +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" namespace xla { namespace protobuf_util { @@ -31,5 +37,35 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } +StatusOr ToJson(const tensorflow::protobuf::Message& message) { + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + message, &json_output, json_options); + if (!status.ok()) { + return InternalError("MessageToJsonString failed: %s", + status.error_message().data()); + } + return json_output; +} + +Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name) { + TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); + + tensorflow::Env* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string safe_file_name = file_name + ".json"; + for (char& c : safe_file_name) { + if (c == '/' || c == '\\') { + c = '_'; + } + } + const string path = tensorflow::io::JoinPath(directory, safe_file_name); + return tensorflow::WriteStringToFile(env, path, json_output); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 36247f1bdec..1a895c35859 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -29,6 +31,17 @@ namespace protobuf_util { // base, this form of equality checking is sufficient. extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); + +// Returns 'message' as a JSON string. +StatusOr ToJson(const tensorflow::protobuf::Message& message); + +// Converts 'message' to JSON, and dumps it to the path formed by joining +// 'directory/file_name.json'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing illegal +// characters with underscore '_'. +Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 142d2c2163f..e8de559a5ef 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/reference_util.h" #include +#include #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -134,12 +135,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( + const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { - std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), - operand.n4()}; + std::vector dim_lengths{operand.height(), operand.width()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); std::vector window_counts(window.size(), 0); @@ -149,6 +149,61 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, WindowCount(dim_lengths[i], window[i], stride[i], padding); pad_low[i] = padding_both[i].first; } + auto result = MakeUnique>(window_counts[0], window_counts[1]); + + // Do a full 2D reduce window. + for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + int64 i1_base = i1 * stride[1] - pad_low[1]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { + if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && + i0_base + i0_win < operand.n1() && + i1_base + i1_win < operand.n2()) { + val += operand(i0_base + i0_win, i1_base + i1_win); + } + } + } + (*result)(i0, i1) = val; + } + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + return ReduceWindow4DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; + window_counts[i] = + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; + } auto result = MakeUnique>(window_counts[0], window_counts[1], window_counts[2], window_counts[3]); // Do a full 4D reduce window. 
@@ -172,8 +227,9 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
                 i1_base + i1_win < operand.n2() &&
                 i2_base + i2_win < operand.n3() &&
                 i3_base + i3_win < operand.n4()) {
-              val += operand(i0_base + i0_win, i1_base + i1_win,
-                             i2_base + i2_win, i3_base + i3_win);
+              val = reduce_func(
+                  val, operand(i0_base + i0_win, i1_base + i1_win,
+                               i2_base + i2_win, i3_base + i3_win));
             }
           }
         }
@@ -187,6 +243,15 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
   return result;
 }
 
+/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
+    const Array4D<float>& operand, float init,
+    const tensorflow::gtl::ArraySlice<int64>& window,
+    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
+  return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
+                               padding);
+}
+
 /* static */ std::unique_ptr<Array4D<float>>
 ReferenceUtil::SelectAndScatter4DGePlus(
     const Array4D<float>& operand, const Array4D<float>& source, float init,
@@ -267,7 +332,8 @@ ReferenceUtil::ConvArray4DGeneralDimensions(
     std::pair<int64, int64> kernel_stride, Padding padding,
     ConvolutionDimensionNumbers dimension_numbers) {
   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
-                                             {1, 1}, {1, 1}, dimension_numbers);
+                                             {1, 1}, {1, 1},
+                                             std::move(dimension_numbers));
 }
 
 /* static */ std::unique_ptr<Array4D<float>>
@@ -335,32 +401,57 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
                                      result_dimensions[2], result_dimensions[3]);
   result->Fill(0.0);
 
+  const auto is_int32 = [](int64 x) {
+    return x >= std::numeric_limits<int32>::min() &&
+           x <= std::numeric_limits<int32>::max();
+  };
+
+  // 64-bit idiv/imod are much more expensive than 32-bit idiv/imod (at least
+  // on x86-64), so we avoid them where possible.
+  const auto fast_idiv64 = [&](int64 a, int64 b) {
+    if (is_int32(a) && is_int32(b)) {
+      return static_cast<int64>(static_cast<int32>(a) / static_cast<int32>(b));
+    }
+    return a / b;
+  };
+  const auto fast_imod64 = [&](int64 a, int64 b) {
+    if (is_int32(a) && is_int32(b)) {
+      return static_cast<int64>(static_cast<int32>(a) % static_cast<int32>(b));
+    }
+    return a % b;
+  };
+
   // Lambda to access the lhs operand at the given 4D index.
   const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
                                int64 width) {
-    if (height % dy != 0 || width % dx != 0) {
+    if (fast_imod64(height, dy) != 0 || fast_imod64(width, dx) != 0) {
       return 0.0f;
     }
 
     std::array<int64, 4> index;
     index[dnums.batch_dimension()] = batch;
     index[dnums.feature_dimension()] = feature;
-    index[dnums.spatial_dimensions(0)] = height / dy;
-    index[dnums.spatial_dimensions(1)] = width / dx;
+    index[dnums.spatial_dimensions(0)] = fast_idiv64(height, dy);
+    index[dnums.spatial_dimensions(1)] = fast_idiv64(width, dx);
     return lhs(index[0], index[1], index[2], index[3]);
   };
 
-  // Lambda to access the rhs operand at the given 4D index.
-  const auto rhs_element = [&](int64 kernel_output_feature,
-                               int64 kernel_input_feature, int64 height,
-                               int64 width) {
-    CHECK_EQ(height % dky, 0);
-    CHECK_EQ(width % dkx, 0);
+  // Lambda to access the rhs operand at the given 4D index.  height_over_dky
+  // should be equal to height / dky, and width_over_dkx should be equal to
+  // width / dkx.  (This is an optimization to avoid doing divisions.)
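+  // For example, with dky == 2 the caller's loop visits height == 0, 2, 4, ...
+  // while height_over_dky counts 0, 1, 2, ..., so no division has to be
+  // performed inside the inner loop.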
+ const auto rhs_element = [&]( + int64 kernel_output_feature, int64 kernel_input_feature, int64 height, + int64 width, int64 height_over_dky, int64 width_over_dkx) { + DCHECK_EQ(height % dky, 0); + DCHECK_EQ(width % dkx, 0); + DCHECK_EQ(height / dky, height_over_dky); + DCHECK_EQ(width / dkx, width_over_dkx); + std::array index; index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; - index[dnums.kernel_spatial_dimensions(0)] = height / dky; - index[dnums.kernel_spatial_dimensions(1)] = width / dkx; + index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; + index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; return rhs(index[0], index[1], index[2], index[3]); }; @@ -380,14 +471,17 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( for (int64 sample = 0; sample < samples; ++sample) { for (int64 izi = 0; izi < iz; ++izi) { for (int64 ozi = 0; ozi < oz; ++ozi) { - for (int64 kyi = 0; kyi < ky; kyi += dky) { - for (int64 kxi = 0; kxi < kx; kxi += dkx) { + for (int64 kyi = 0, kyi_over_dky = 0; kyi < ky; + kyi += dky, kyi_over_dky++) { + for (int64 kxi = 0, kxi_over_dkx = 0; kxi < kx; + kxi += dkx, kxi_over_dkx++) { int64 iyi = istarty + ksy * oyi + kyi; int64 ixi = istartx + ksx * oxi + kxi; float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0) ? 0.0 : lhs_element(sample, izi, iyi, ixi); - float gain = rhs_element(ozi, izi, kyi, kxi); + float gain = + rhs_element(ozi, izi, kyi, kxi, kyi_over_dky, kxi_over_dkx); float addend = input * gain; result_element(sample, ozi, oyi, oxi) += addend; } @@ -571,4 +665,49 @@ ReferenceUtil::ReduceToRowArray2D( return result; } +/* static */ Array4D ReferenceUtil::PadArray4D( + const Array4D& operand, const PaddingConfig& padding, + const float pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector pad_low(4); + std::vector pad_high(4); + std::vector pad_interior(4); + std::vector output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], + output_bounds[3]); + result.Each([&](tensorflow::gtl::ArraySlice indices, float* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + return; + } + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1), + (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); + }); + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index d19d5f9dbb6..f58f0bdc9f5 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -144,12 +144,31 @@ class ReferenceUtil { static int64 
WindowCount(int64 unpadded_width, int64 window_len, int64 stride, Padding padding); + // Performs a 2D window reduction with Add as the function to apply. + static std::unique_ptr> ReduceWindow2DAdd( + const Array2D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); + // Performs a 4D window reduction with Add as the function to apply. static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + // Performs a 4D window reduction with a generic reduce function. + static std::unique_ptr> ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); + // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. static std::unique_ptr> SelectAndScatter4DGePlus( @@ -382,7 +401,51 @@ class ReferenceUtil { const Array2D& operand, const PaddingConfig& padding, const float pad); + // Returns the result of a 4D pad on an input array. + static Array4D PadArray4D(const Array4D& operand, + const PaddingConfig& padding, + const float pad); + + // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running + // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... + // + // The given arrays must have the same size and element type, and the return + // type of f must be implicitly convertible to the arrays' element type. + // + // Example usage: + // + // Array2D x, y, z = ...; + // std::unique_ptr result = ReferenceUtil::ApplyElementwise2D( + // [](float a, float b, float c) { return a * b + c; }, x, y, z); + // + template + static std::unique_ptr> ApplyElementwise2D( + F&& f, const Array2D& array1, const Array2D&... arrays) { + AssertSameSize2D(array1, arrays...); + auto result = MakeUnique>(array1.n1(), array1.n2()); + for (int64 i = 0; i < array1.n1(); ++i) { + for (int64 j = 0; j < array1.n2(); ++j) { + (*result)(i, j) = f(array1(i, j), arrays(i, j)...); + } + } + return result; + } + private: + template + static void AssertSameSize2D(const Array2D& array1, + const Array2D& array2, + const Array2D&... arrays) { + static_assert(std::is_same::value, "Args must be same type."); + CHECK_EQ(array1.n1(), array2.n1()); + CHECK_EQ(array1.n2(), array2.n2()); + AssertSameSize2D(array2, arrays...); + } + + // Recursive base case for AssertSameSize2D. + template + static void AssertSameSize2D(const Array1& array1) {} + TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil); }; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index c53351ca93e..f839ac019df 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -23,9 +23,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -52,9 +52,9 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -62,32 +62,32 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal, + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,9 +96,9 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -107,11 +107,11 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); 
expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -124,11 +124,11 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -302,5 +302,17 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, ApplyElementwise2D) { + Array2D a({{1, 2}, {3, 4}}); + Array2D b({{10, 20}, {30, 40}}); + Array2D c({{100, 200}, {300, 400}}); + + auto actual = ReferenceUtil::ApplyElementwise2D( + [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4a59ce2f17e..0687368b83d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -21,6 +21,14 @@ xla_proto_library( deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) +xla_proto_library( + name = "hlo_proto", + srcs = ["hlo.proto"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + # Filegroup used to collect source files for dependency checking. 
filegroup( name = "c_srcs", @@ -48,10 +56,12 @@ cc_library( cc_test( name = "shape_inference_test", + size = "small", srcs = ["shape_inference_test.cc"], deps = [ ":shape_inference", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -61,11 +71,49 @@ cc_test( cc_test( name = "hlo_opcode_test", + size = "small", srcs = ["hlo_opcode_test.cc"], deps = [ ":hlo", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_evaluator", + srcs = ["hlo_evaluator.cc"], + hdrs = ["hlo_evaluator.h"], + deps = [ + ":hlo", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_evaluator_test", + size = "small", + srcs = ["hlo_evaluator_test.cc"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -88,6 +136,8 @@ cc_library( "hlo_opcode.h", ], deps = [ + ":hlo_module_config", + ":hlo_proto", ":name_uniquer", ":versioned_computation_handle", "//tensorflow/compiler/xla:literal_util", @@ -105,10 +155,34 @@ cc_library( ], ) +cc_library( + name = "hlo_matchers", + testonly = 1, + srcs = ["hlo_matchers.cc"], + hdrs = ["hlo_matchers.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "hlo_matchers_test", + size = "small", + srcs = ["hlo_matchers_test.cc"], + deps = [ + ":hlo_matchers", + "//tensorflow/compiler/xla:shape_util", + ], +) + cc_library( name = "versioned_computation_handle", + srcs = ["versioned_computation_handle.cc"], hdrs = ["versioned_computation_handle.h"], deps = [ + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -116,14 +190,82 @@ cc_library( cc_test( name = "hlo_instruction_test", + size = "small", srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/core:test_main", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "call_graph", + srcs = ["call_graph.cc"], + hdrs = ["call_graph.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "call_graph_test", + size = "small", + srcs = ["call_graph_test.cc"], + deps = [ + ":call_graph", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + 
"//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "flatten_call_graph", + srcs = ["flatten_call_graph.cc"], + hdrs = ["flatten_call_graph.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "flatten_call_graph_test", + size = "small", + srcs = ["flatten_call_graph_test.cc"], + deps = [ + ":call_graph", + ":flatten_call_graph", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", ], ) @@ -143,10 +285,30 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/core:lib", ], ) +cc_test( + name = "user_computation_test", + size = "small", + srcs = ["user_computation_test.cc"], + deps = [ + ":hlo_matchers", + ":user_computation", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + ], +) + cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -170,6 +332,7 @@ cc_library( ":compiler", ":device_memory_allocator", ":platform_util", + ":pool", ":transfer_manager", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -199,10 +362,10 @@ cc_library( ":device_memory_allocator", ":executable", ":execution_tracker", + ":gpu_transfer_manager", ":hlo", ":hlo_cost_analysis", ":hlo_execution_profile", - ":hlo_graph_dumper", ":hlo_module_config", ":platform_util", ":session_proto", @@ -222,7 +385,6 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = 1, @@ -254,6 +416,29 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "compile_only_service", + srcs = ["compile_only_service.cc"], + hdrs = ["compile_only_service.h"], + deps = [ + ":backend", + ":compiler", + ":computation_layout", + ":computation_tracker", + ":platform_util", + ":service", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ 
-272,7 +457,7 @@ cc_library( cc_library( name = "gpu_plugin", deps = [ - ":generic_transfer_manager", + ":gpu_transfer_manager", ":service", "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -301,23 +486,31 @@ cc_library( cc_library( name = "executable", srcs = ["executable.cc"], - hdrs = ["executable.h"], + hdrs = [ + "executable.h", + "service_executable_run_options.h", + ], deps = [ ":computation_layout", ":device_memory_allocator", ":hlo", + ":hlo_cost_analysis", ":hlo_execution_profile", - ":hlo_module_config", + ":hlo_graph_dumper", + ":pool", ":session_proto", ":shaped_buffer", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", ], ) @@ -329,6 +522,7 @@ cc_library( ":executable", ":hlo", ":hlo_module_config", + ":logical_buffer", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -342,6 +536,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -378,6 +573,7 @@ cc_library( hdrs = ["execution_tracker.h"], deps = [ ":backend", + ":pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -393,6 +589,7 @@ cc_library( hdrs = ["computation_tracker.h"], deps = [ ":hlo", + ":hlo_module_config", ":session_proto", ":user_computation", ":versioned_computation_handle", @@ -435,6 +632,32 @@ cc_library( ], ) +cc_library( + name = "liveness_util", + srcs = ["liveness_util.cc"], + hdrs = ["liveness_util.h"], + deps = [ + ":hlo", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_test( + name = "liveness_util_test", + size = "small", + srcs = ["liveness_util_test.cc"], + deps = [ + ":hlo", + ":liveness_util", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + cc_library( name = "buffer_liveness", srcs = [ @@ -446,6 +669,7 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -459,10 +683,10 @@ cc_library( cc_test( name = "buffer_liveness_test", + size = "small", srcs = ["buffer_liveness_test.cc"], deps = [ ":buffer_liveness", - ":cpu_plugin", ":hlo", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -484,6 +708,8 @@ cc_library( deps = [ ":buffer_liveness", ":hlo", + ":hlo_ordering", + ":hlo_proto", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -494,38 +720,68 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) cc_test( name = "buffer_assignment_test", + size = "small", srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + 
":call_graph", ":computation_tracker", - ":cpu_plugin", + ":copy_insertion", + ":flatten_call_graph", ":hlo", + ":hlo_ordering", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", + ], +) + +cc_test( + name = "heap_simulator_test", + size = "small", + srcs = ["heap_simulator_test.cc"], + deps = [ + ":hlo", + ":hlo_ordering", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) +# The hlo_ordering library contains both hlo_ordering and heap_simulator because +# they are mutually dependent. cc_library( name = "hlo_ordering", srcs = [ + "heap_simulator.cc", "hlo_ordering.cc", ], hdrs = [ + "heap_simulator.h", "hlo_ordering.h", ], deps = [ + ":call_graph", ":hlo", + ":hlo_proto", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -539,16 +795,15 @@ cc_library( cc_test( name = "hlo_ordering_test", + size = "small", srcs = ["hlo_ordering_test.cc"], deps = [ - ":cpu_plugin", ":hlo", ":hlo_ordering", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -577,11 +832,12 @@ cc_library( cc_test( name = "instruction_fusion_test", + size = "small", srcs = ["instruction_fusion_test.cc"], deps = [ + ":hlo_matchers", ":instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -607,21 +863,21 @@ cc_library( cc_test( name = "algebraic_simplifier_test", + size = "small", srcs = ["algebraic_simplifier_test.cc"], deps = [ ":algebraic_simplifier", - ":cpu_plugin", ":hlo", + ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", - "//tensorflow/core:test_main", ], ) @@ -631,26 +887,31 @@ cc_library( hdrs = ["reshape_mover.h"], deps = [ ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", ], ) cc_test( name = "reshape_mover_test", + size = "small", srcs = ["reshape_mover_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":reshape_mover", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", - "//tensorflow/core:test_main", ], ) @@ -670,14 +931,15 @@ cc_library( cc_test( name = "inliner_test", + size = "small", srcs = ["inliner_test.cc"], deps = [ - ":cpu_plugin", ":hlo", + ":hlo_matchers", ":inliner", 
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -726,8 +988,30 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) +cc_library( + name = "gpu_transfer_manager", + srcs = ["gpu_transfer_manager.cc"], + hdrs = ["gpu_transfer_manager.h"], + deps = [ + ":generic_transfer_manager", + ":transfer_manager", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + cc_test( name = "transfer_manager_test", + size = "small", srcs = ["transfer_manager_test.cc"], deps = [ ":cpu_transfer_manager", @@ -765,6 +1049,7 @@ cc_library( cc_test( name = "hlo_cost_analysis_test", + size = "small", srcs = ["hlo_cost_analysis_test.cc"], deps = [ ":computation_tracker", @@ -805,12 +1090,14 @@ cc_library( cc_test( name = "hlo_computation_test", + size = "small", srcs = ["hlo_computation_test.cc"], deps = [ - ":cpu_plugin", ":hlo", + ":hlo_matchers", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test_main", @@ -834,18 +1121,17 @@ cc_binary( cc_test( name = "hlo_module_test", + size = "small", srcs = ["hlo_module_test.cc"], deps = [ - ":cpu_plugin", ":hlo", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", - "//tensorflow/core:test_main", ], ) @@ -859,10 +1145,101 @@ cc_library( ], deps = [ ":hlo", + ":hlo_proto", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "hlo_dataflow_analysis", + srcs = [ + "hlo_dataflow_analysis.cc", + ], + hdrs = [ + "hlo_dataflow_analysis.h", + ], + deps = [ + ":call_graph", + ":hlo", + ":liveness_util", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_test( + name = "hlo_dataflow_analysis_test", + size = "small", + srcs = ["hlo_dataflow_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_matchers", + ":instruction_fusion", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + 
"//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_alias_analysis", + srcs = [ + "hlo_alias_analysis.cc", + ], + hdrs = [ + "hlo_alias_analysis.h", + ], + deps = [ + ":call_graph", + ":hlo", + ":hlo_dataflow_analysis", + ":logical_buffer", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_alias_analysis_test", + srcs = ["hlo_alias_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_alias_analysis", + ":hlo_matchers", + ":instruction_fusion", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -889,12 +1266,16 @@ cc_library( cc_test( name = "tuple_points_to_analysis_test", + size = "small", srcs = ["tuple_points_to_analysis_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":instruction_fusion", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -956,6 +1337,7 @@ cc_library( ":buffer_liveness", ":hlo", ":hlo_pass", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:status_macros", @@ -968,19 +1350,19 @@ cc_library( cc_test( name = "copy_insertion_test", + size = "small", srcs = ["copy_insertion_test.cc"], deps = [ - ":buffer_liveness", ":copy_insertion", - ":cpu_plugin", ":hlo", + ":hlo_matchers", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -1000,11 +1382,59 @@ cc_library( ], ) +cc_library( + name = "hlo_verifier", + srcs = ["hlo_verifier.cc"], + hdrs = ["hlo_verifier.h"], + deps = [":hlo_pass"], +) + +cc_library( + name = "hlo_rematerialization", + srcs = ["hlo_rematerialization.cc"], + hdrs = ["hlo_rematerialization.h"], + deps = [ + ":buffer_liveness", + ":call_graph", + ":flatten_call_graph", + ":hlo", + ":hlo_cost_analysis", + ":hlo_dce", + ":hlo_ordering", + ":liveness_util", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_rematerialization_test", + size = "small", + srcs = ["hlo_rematerialization_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_ordering", + ":hlo_rematerialization", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", 
+ "//tensorflow/core:test_main", + ], +) + cc_test( name = "hlo_dce_test", + size = "small", srcs = ["hlo_dce_test.cc"], deps = [ - ":cpu_plugin", ":hlo", ":hlo_dce", "//tensorflow/compiler/xla:literal_util", @@ -1016,29 +1446,30 @@ cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", - "//tensorflow/core:test_main", + "//tensorflow/core:test", ], ) cc_test( name = "layout_assignment_test", + size = "small", srcs = ["layout_assignment_test.cc"], deps = [ ":algebraic_simplifier", ":computation_layout", - ":cpu_plugin", ":hlo", + ":hlo_matchers", ":layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", - "//tensorflow/core:test_main", ], ) @@ -1073,7 +1504,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/core:lib", ], ) @@ -1087,7 +1517,6 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -1096,11 +1525,12 @@ cc_library( cc_test( name = "hlo_cse_test", + size = "small", srcs = ["hlo_cse_test.cc"], deps = [ - ":cpu_plugin", ":hlo", ":hlo_cse", + ":hlo_matchers", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1110,7 +1540,42 @@ cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", - "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_constant_folding", + srcs = ["hlo_constant_folding.cc"], + hdrs = ["hlo_constant_folding.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_constant_folding_test", + size = "small", + srcs = ["hlo_constant_folding_test.cc"], + deps = [ + ":hlo", + ":hlo_constant_folding", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", ], ) @@ -1159,6 +1624,7 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", @@ -1188,6 +1654,7 @@ cc_library( cc_test( name = "hlo_subcomputation_unification_test", + size = "small", srcs = ["hlo_subcomputation_unification_test.cc"], deps = [ ":hlo", @@ -1196,7 +1663,33 @@ cc_test( "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_tfgraph_builder", + srcs = ["hlo_tfgraph_builder.cc"], + hdrs = ["hlo_tfgraph_builder.h"], + visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "hlo_tfgraph_builder_test", + size = "small", + srcs = ["hlo_tfgraph_builder_test.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:protos_all_cc", ], ) @@ -1209,9 +1702,11 @@ cc_library( deps = [ ":hlo", ":hlo_execution_profile", + ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", "//tensorflow/core:lib", ], @@ -1225,7 +1720,9 @@ cc_library( deps = [ ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/core:lib", ], @@ -1233,20 +1730,57 @@ cc_library( cc_test( name = "transpose_folding_test", + size = "small", srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) +cc_library( + name = "pool", + hdrs = ["pool.h"], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "pool_test", + size = "small", + srcs = ["pool_test.cc"], + deps = [ + ":pool", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_proto_util", + srcs = ["hlo_proto_util.cc"], + hdrs = ["hlo_proto_util.h"], + deps = [ + ":buffer_assignment", + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:status", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index d35c6d6adb0..754ac0c68dc 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -51,6 +51,16 @@ bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { LiteralUtil::IsAll(operand->literal(), value); } +bool IsAll(const HloInstruction* op, int8 value) { + if (IsLiteralWithValue(op, value)) { + return true; + } + if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { + return true; + } + return false; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. 
bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -76,6 +86,24 @@ bool ReshapeIsBitcast( return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && valid_bitcast_callback(operand->shape(), reshape->shape()); } + +// Adds a scalar computation to the module to enable optimizations with dot +// converting into reduction. +HloComputation* CreateScalarBinaryComputation(HloModule* module, + PrimitiveType primitive_type, + HloOpcode opcode) { + HloComputation::Builder b("scalar computation"); + auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), + opcode, scalar_lhs, scalar_rhs)); + HloComputation* scalar_computation = + module->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_computation; +} } // namespace // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -94,6 +122,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; Status HandleConvert(HloInstruction* convert, @@ -105,6 +137,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) override; + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; @@ -125,7 +160,19 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice dimensions, HloComputation* function) override; + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override; + + Status HandleReverse(HloInstruction* reverse, + HloInstruction* operand) override; Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + Status HandleDynamicSlice(HloInstruction* slice, HloInstruction* operand, + HloInstruction* start_indices) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -144,15 +191,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. 
 static bool Run(
     HloComputation* computation, bool is_layout_sensitive,
-    AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback);
+    AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
+    bool enable_dot_simplification);

 private:
  explicit AlgebraicSimplifierVisitor(
      HloComputation* computation, bool is_layout_sensitive,
-     AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback)
+     AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
+     bool enable_dot_simplification)
      : computation_(computation),
        is_layout_sensitive_(is_layout_sensitive),
-       valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+       valid_bitcast_callback_(std::move(valid_bitcast_callback)),
+       enable_dot_simplification_(enable_dot_simplification) {}

  // Convenience method for replacing an instruction with a bitcast.
  void ReplaceWithBitcast(HloInstruction* instruction);
@@ -179,6 +229,34 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

+  // A Reshape or Broadcast that feeds an element-wise operation with a unique
+  // non-scalar operand can sink to after the operation.
+  StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
+      HloInstruction* reshape_or_broadcast);
+
+  // Replaces the existing HLO instruction old_instruction with
+  // new_instruction, and marks the optimizer status as changed.
+  // Returns the Status representing the result of the replace operation.
+  Status ReplaceWithNewInstruction(
+      HloInstruction* old_instruction,
+      std::unique_ptr<HloInstruction> new_instruction) {
+    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+        old_instruction, std::move(new_instruction)));
+    changed_ = true;
+    return Status::OK();
+  }
+
+  // Replaces the existing HLO instruction old_instruction with
+  // new_instruction, and marks the optimizer status as changed.
+  // Returns the Status representing the result of the replace operation.
+  Status ReplaceInstruction(HloInstruction* old_instruction,
+                            HloInstruction* new_instruction) {
+    TF_RETURN_IF_ERROR(
+        computation_->ReplaceInstruction(old_instruction, new_instruction));
+    changed_ = true;
+    return Status::OK();
+  }
+
  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;
@@ -191,13 +269,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
  // Callback used to determine if a bitcast is possible.
  AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;
+
+  // Disable dot simplification on platforms where it causes a slowdown.
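+  // (Illustrative note: when this is false, HandleDot returns immediately
+  // and all kDot instructions are left untouched.)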
+ bool enable_dot_simplification_; }; bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) { + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, + bool enable_dot_simplification) { AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, - std::move(valid_bitcast_callback)); + std::move(valid_bitcast_callback), + enable_dot_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -222,8 +305,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast( auto bitcast = computation_->AddInstruction( HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, instruction->mutable_operand(0))); - TF_CHECK_OK(computation_->ReplaceInstruction(instruction, bitcast)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( @@ -231,9 +313,7 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( if (!SameShape(old_instruction, new_instruction)) { return false; } - TF_CHECK_OK( - computation_->ReplaceInstruction(old_instruction, new_instruction)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); return true; } @@ -242,12 +322,12 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, HloInstruction* rhs) { // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); - if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { return Status::OK(); } // 0 + A => A VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); - if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { return Status::OK(); } @@ -256,17 +336,91 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, HloInstruction* operand) { + // If a copy feeds a copy, make it a single copy. + if (operand->opcode() == HloOpcode::kCopy) { + return ReplaceWithNewInstruction( + copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, + operand->operands()[0])); + } // All copies can be eliminated (assuming layout constraints are satisified). ReplaceInstructionIfSameShape(copy, operand); return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + if (operands.size() == 1) { + // Unary concatenates are useless. + ReplaceInstructionIfSameShape(concatenate, operands[0]); + return Status::OK(); + } + // Filter out and remove empty operands. 
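+  // Illustrative example (hypothetical shapes): concatenate(f32[0], f32[5],
+  // f32[0]) along dimension 0 has two empty operands; after filtering, only
+  // the f32[5] operand remains and the concatenate is replaced by it.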
+ std::vector nonempty_operands; + for (HloInstruction* operand : operands) { + if (!ShapeUtil::HasZeroElements(operand->shape())) { + nonempty_operands.push_back(operand); + } + } + if (nonempty_operands.size() < operands.size()) { + HloInstruction* replacement; + if (nonempty_operands.empty()) { + replacement = operands[0]; + } else if (nonempty_operands.size() == 1) { + replacement = nonempty_operands[0]; + } else { + replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), nonempty_operands)); + } + VLOG(10) << "trying to replace " << concatenate->ToString() << " with " + << replacement->ToString(); + ReplaceInstructionIfSameShape(concatenate, replacement); + } else if (operands.size() == 2) { + // A binary concat with a broadcasted scalar as an operand can be converted + // into a pad which is simpler to fold into other operations. + bool is_effective_low_pad = + operands[0]->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); + bool is_effective_high_pad = + operands[1]->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); + if (!is_effective_low_pad && !is_effective_high_pad) { + return Status::OK(); + } + PaddingConfig padding_config; + for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_edge_padding_high(0); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_interior_padding(0); + if (dim == concatenate->concatenate_dimension()) { + if (is_effective_low_pad) { + padding_config_dim->set_edge_padding_low( + operands[0]->shape().dimensions(dim)); + } else { + padding_config_dim->set_edge_padding_high( + operands[1]->shape().dimensions(dim)); + } + } + } + int64 operand_to_pad = is_effective_low_pad ? 1 : 0; + int64 pad_value_operand = is_effective_low_pad ? 
0 : 1;
+    HloInstruction* pad =
+        computation_->AddInstruction(HloInstruction::CreatePad(
+            concatenate->shape(), operands[operand_to_pad],
+            operands[pad_value_operand]->mutable_operand(0), padding_config));
+    return ReplaceInstruction(concatenate, pad);
+  }
+  return Status::OK();
+}
+
 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub,
                                                   HloInstruction* lhs,
                                                   HloInstruction* rhs) {
   // A - 0 => A
   VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
-  if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
+  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
     return Status::OK();
   }
@@ -278,8 +432,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
                                                 HloInstruction* rhs) {
   // A/1 => A
   VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
-  if (IsLiteralWithValue(rhs, 1) &&
-      ReplaceInstructionIfSameShape(divide, lhs)) {
+  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
     return Status::OK();
   }
@@ -290,8 +443,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
       computation_->AddInstruction(HloInstruction::CreateBinary(
           divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
           rhs->mutable_operand(0)));
-  changed_ = true;
-  return computation_->ReplaceWithNewInstruction(
+  return ReplaceWithNewInstruction(
       divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
                                           subtract));
 }
@@ -299,19 +451,148 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
   return Status::OK();
 }

+Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
+                                             HloInstruction* lhs,
+                                             HloInstruction* rhs) {
+  if (!enable_dot_simplification_) {
+    return Status::OK();
+  }
+  // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
+  // below.
+  if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
+      ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
+    return Status::OK();
+  }
+
+  // Replace a zero element dot with a broadcast of the constant 0.
+  if (ShapeUtil::HasZeroElements(dot->shape()) ||
+      ShapeUtil::HasZeroElements(lhs->shape()) ||
+      ShapeUtil::HasZeroElements(rhs->shape())) {
+    auto zero = computation_->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+    return ReplaceWithNewInstruction(
+        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
+  }
+
+  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
+  if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
+    auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot,
+        rhs->mutable_operand(0), lhs->mutable_operand(0)));
+    return ReplaceWithNewInstruction(
+        dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
+  }
+
+  // Simplify outer product into multiply with implicit broadcasting.
+  //
+  // A dot(a[M, 1], b[1, N]) = multiply(a[M, 1], b[1, N])
+  if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
+    return ReplaceWithNewInstruction(
+        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
+                                          lhs, rhs));
+  }
+
+  // The following graph transformations take Dots where at least one input is
+  // a vector or has a degenerate dimension and convert them into a multiply
+  // and a reduce. This should enable more fusion than leaving the nodes as
+  // Dot operations.
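+  // Illustrative example (hypothetical shapes): dot(a[1, 3], b[3, 4]) becomes
+  // reshape(reduce(multiply(broadcast(reshape(a, [3]), {0}), b), add, {0}),
+  // [1, 4]), which fuses with neighboring element-wise ops better than an
+  // opaque Dot.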
+ + // Strength reduce dot(a[K] , b[K]) = + // reshape(result.shape, + // reduce_sum(multiply(a, b), {0})) + if (ShapeUtil::Rank(rhs->shape()) == 1 && + ShapeUtil::Rank(lhs->shape()) == 1) { + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, + {0}, add_reduce_computation)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + + // Strength reduce dot(a[1, K], b) = + // reshape(result.shape, + // reduce_sum( + // multiply(broadcast(reshape(a, [K]), {0}), b), + // {0}) + // ) + // ) + if (ShapeUtil::Rank(lhs->shape()) == 1 || + (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { + auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs->shape().element_type(), + {ShapeUtil::ElementsIn(lhs->shape())}), + lhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* reduce; + if (ShapeUtil::Rank(rhs->shape()) == 1) { + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); + reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, + {0}, add_reduce_computation)); + } else { + new_lhs = computation_->AddInstruction( + HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); + + reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), + {rhs->shape().dimensions(1)}), + multiply, zero, {0}, add_reduce_computation)); + } + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + + // Strength reduce dot(a, b[K, 1]) = + // reshape(result.shape, + // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) + // ) + if (ShapeUtil::Rank(rhs->shape()) == 1 || + (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { + auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs->shape().element_type(), + {ShapeUtil::ElementsIn(rhs->shape())}), + rhs)); + new_rhs = computation_->AddInstruction( + HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), + {lhs->shape().dimensions(0)}), + multiply, zero, {1}, 
add_reduce_computation)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); - if (IsLiteralWithValue(rhs, 1) && - ReplaceInstructionIfSameShape(multiply, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { return Status::OK(); } // 1*A => A VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString(); - if (IsLiteralWithValue(lhs, 1) && - ReplaceInstructionIfSameShape(multiply, rhs)) { + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } return Status::OK(); @@ -383,8 +664,9 @@ std::pair> ReshapeLeavesDimensionsUnmodified( return std::make_pair(true, output_dim_indices); } -// Returns true if the output of "instruction" is a permutation of the elements -// of "operand". Precondition: "operand" is an operand of "instruction". +// Returns true if the output of "instruction" is a permutation of the +// elements of "operand". Precondition: "operand" is an operand of +// "instruction". bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, HloInstruction* operand) { DCHECK(!instruction->OperandIndices(operand).empty()); @@ -432,13 +714,25 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " "n(broadcast(X)) == n(X)"; - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); } - // A broadcast of a reshape which merely inserts 1-sized dimensions can elide - // its operand. + // A degenerate broadcast that has the same input and output rank can be + // converted into a transpose. + if (ShapeUtil::Rank(broadcast->shape()) == + ShapeUtil::Rank(operand->shape()) && + ShapeUtil::ElementsIn(broadcast->shape()) == + ShapeUtil::ElementsIn(operand->shape())) { + VLOG(10) << "transform broadcast(X) -> transpose(X) where " + "n(broadcast(X)) == n(X)"; + return ReplaceWithNewInstruction( + broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, + broadcast->dimensions())); + } + + // A broadcast of a reshape which merely inserts 1-sized dimensions can + // elide its operand. { bool merely_inserts_or_deletes_1_sized_dimensions; std::vector inserted_indices, deleted_indices; @@ -452,14 +746,22 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateBroadcast(broadcast->shape(), operand->mutable_operand(0), dims)); } } + // A Broadcast that feeds a unary element-wise operation can sink the + // broadcast after the unary element-wise operation. 
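+  // Illustrative example: with a scalar s, exp(broadcast(s, f32[2,3])) can
+  // become broadcast(exp(s), f32[2,3]), evaluating exp once instead of six
+  // times.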
+ TF_ASSIGN_OR_RETURN( + changed_, + TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + if (changed_) { + return Status::OK(); + } + // A scalar broadcast feeding an instruction which only permutes (reshape, // transpose, sort, reverse) or selects a subset of operand elements (slice, // dynamic slice) can be replaced with a broadcast directly to the output @@ -487,65 +789,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - - return HloInstruction::CreateConstant( - LiteralUtil::Convert::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal)); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. @@ -554,16 +797,7 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, PrimitiveType src_type = operand->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - changed_ = true; - return computation_->ReplaceInstruction(convert, operand); - } - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, dest_type); - changed_ = true; - return computation_->ReplaceWithNewInstruction(convert, - std::move(new_constant)); + return ReplaceInstruction(convert, operand); } return Status::OK(); } @@ -626,6 +860,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // Second, construct the slice instruction to perform the negative padding. 
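+  // Illustrative example (hypothetical shape): pad(f32[10],
+  // edge_padding_low=-2) becomes a pad with no negative amounts followed by
+  // slice(start=2, end=10), yielding f32[8].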
std::vector start_indices; std::vector end_indices; + std::vector strides; for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) { const PaddingConfig::PaddingConfigDimension& padding_dimension = pad->padding_config().dimensions(i); @@ -639,18 +874,19 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } start_indices.push_back(start); end_indices.push_back(end); + strides.push_back(1); } // Verify that the slice shape matches the pad shape. TF_ASSIGN_OR_RETURN(Shape inferred_slice_shape, ShapeInference::InferSliceShape( - nonzero_pad_shape, start_indices, end_indices)); + nonzero_pad_shape, start_indices, end_indices, + strides)); TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); std::unique_ptr slice = HloInstruction::CreateSlice( - pad->shape(), nonzero_pad, start_indices, end_indices); - changed_ = true; - return computation_->ReplaceWithNewInstruction(pad, std::move(slice)); + pad->shape(), nonzero_pad, start_indices, end_indices, strides); + return ReplaceWithNewInstruction(pad, std::move(slice)); } return Status::OK(); @@ -660,7 +896,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 0)) { + if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(power->shape().element_type()))); std::unique_ptr ones; @@ -670,51 +906,122 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, ones = HloInstruction::CreateBroadcast( power->shape(), computation_->AddInstruction(std::move(one)), {}); } - changed_ = true; - return computation_->ReplaceWithNewInstruction(power, std::move(ones)); - return Status::OK(); + return ReplaceWithNewInstruction(power, std::move(ones)); } VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { return Status::OK(); } VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 2)) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + if (IsAll(rhs, 2)) { + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kMultiply, lhs, lhs)); } VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, -1)) { + if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(rhs->shape().element_type())))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); } return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor:: + TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* reshape_or_broadcast) { + bool changed = false; + HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); + for (HloInstruction* user : reshape_or_broadcast->users()) { + if (user->user_count() == 0 && user != computation_->root_instruction()) { + continue; + } + // Do not move reshapes or broadcasts past copies since the shape the copy + // will operate on will change. 
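+    // (For instance, copy(broadcast(x)) may exist solely to change layout;
+    // rewriting it as broadcast(copy(x)) would alter the shape the copy
+    // converts.)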
+    if (user->opcode() == HloOpcode::kCopy) {
+      continue;
+    }
+    // Do not change the shape of fusion nodes in case there are already
+    // multiple shapes inside the fusion node.
+    if (user->opcode() == HloOpcode::kFusion) {
+      continue;
+    }
+    if (!user->IsElementwise()) {
+      continue;
+    }
+
+    int64 reshape_or_broadcast_operand_index = -1;
+    // Find the unique non-scalar operand or continue if there isn't one.
+    int64 scalar_count = 0;
+    for (int64 i = 0; i < user->operand_count(); ++i) {
+      if (ShapeUtil::IsScalar(user->operand(i)->shape())) {
+        ++scalar_count;
+      } else {
+        reshape_or_broadcast_operand_index = i;
+      }
+    }
+    if (scalar_count != user->operand_count() - 1) {
+      continue;
+    }
+    CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
+             reshape_or_broadcast);
+    std::vector<HloInstruction*> new_user_operands = user->operands();
+    new_user_operands[reshape_or_broadcast_operand_index] = operand;
+    auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
+        ShapeUtil::MakeShape(user->shape().element_type(),
+                             AsInt64Slice(operand->shape().dimensions())),
+        new_user_operands));
+    HloInstruction* new_reshape_or_broadcast = nullptr;
+    if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) {
+      new_reshape_or_broadcast =
+          computation_->AddInstruction(HloInstruction::CreateReshape(
+              ShapeUtil::MakeShape(
+                  user->shape().element_type(),
+                  AsInt64Slice(reshape_or_broadcast->shape().dimensions())),
+              new_user));
+    } else {
+      TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast);
+      new_reshape_or_broadcast =
+          computation_->AddInstruction(HloInstruction::CreateBroadcast(
+              ShapeUtil::MakeShape(
+                  user->shape().element_type(),
+                  AsInt64Slice(reshape_or_broadcast->shape().dimensions())),
+              new_user, reshape_or_broadcast->dimensions()));
+    }
+    TF_RETURN_IF_ERROR(
+        computation_->ReplaceUsesOfInstruction(user, new_reshape_or_broadcast));
+    changed = true;
+  }
+  return changed;
+}
+
 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
   auto operand = reshape->mutable_operand(0);

+  // Reshape directly to an empty constant if the shape contains a
+  // zero-element dimension.
+  if (ShapeUtil::HasZeroElements(reshape->shape())) {
+    auto empty_constant = HloInstruction::CreateConstant(
+        LiteralUtil::CreateFromShape(reshape->shape()));
+
+    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
+  }
+
   // Delete no-op reshapes, i.e. where shape = operand shape.
   if (SameShape(reshape, operand)) {
     VLOG(10) << "deleting no-op reshape";
-    changed_ = true;
-    return computation_->ReplaceInstruction(reshape, operand);
+    return ReplaceInstruction(reshape, operand);
   }

   // Merge reshapes.
   if (HloOpcode::kReshape == operand->opcode()) {
-    changed_ = true;
-    return computation_->ReplaceWithNewInstruction(
+    return ReplaceWithNewInstruction(
         reshape, HloInstruction::CreateReshape(reshape->shape(),
                                                operand->mutable_operand(0)));
   }
@@ -723,8 +1030,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
     auto opt_dims = ReshapeLeavesDimensionsUnmodified(
         reshape, reshape->operand(0)->dimensions());
     if (opt_dims.first) {
-      changed_ = true;
-      return computation_->ReplaceWithNewInstruction(
+      return ReplaceWithNewInstruction(
           reshape, HloInstruction::CreateBroadcast(
                        reshape->shape(),
                        reshape->mutable_operand(0)->mutable_operand(0),
@@ -732,6 +1038,15 @@
     }
   }

+  // A Reshape that feeds a unary element-wise operation can sink the
+  // reshape after the unary element-wise operation.
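+  // Illustrative example (hypothetical shapes): negate(reshape(x: f32[6] ->
+  // f32[2,3])) can become reshape(negate(x): f32[6] -> f32[2,3]), moving the
+  // element-wise op next to x's producer.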
+  TF_ASSIGN_OR_RETURN(
+      changed_,
+      TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape));
+  if (changed_) {
+    return Status::OK();
+  }
+
   // Make this a bitcast if possible.
   if (is_layout_sensitive_ &&
       ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
@@ -742,6 +1057,20 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
   return Status::OK();
 }

+Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse,
+                                                 HloInstruction* operand) {
+  // When all the dimensions to reverse are trivial (i.e. the bound is 1),
+  // there is nothing to be done.
+  auto dim_is_one = [&](int64 i) -> bool {
+    return reverse->shape().dimensions(i) == 1;
+  };
+  if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
+                  dim_is_one)) {
+    return ReplaceInstruction(reverse, operand);
+  }
+  return Status::OK();
+}
+
 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
                                                HloInstruction* operand) {
   // Delete no-op slices, i.e. where shape = operand shape.
@@ -751,34 +1080,176 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
   return Status::OK();
 }

+Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
+    HloInstruction* dynamic_slice, HloInstruction* operand,
+    HloInstruction* start_indices) {
+  if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
+    return ReplaceInstruction(dynamic_slice, operand);
+  }
+  return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
+    HloInstruction* dynamic_update_slice, HloInstruction* operand,
+    HloInstruction* update, HloInstruction* start_indices) {
+  // DynamicUpdateSlice on a scalar just passes through the update argument.
+  if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
+    return ReplaceInstruction(dynamic_update_slice, update);
+  }
+  return Status::OK();
+}
+
 Status AlgebraicSimplifierVisitor::HandleReduce(
     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
     tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
+  if (ShapeUtil::HasZeroElements(arg->shape()) ||
+      ShapeUtil::HasZeroElements(reduce->shape())) {
+    return ReplaceWithNewInstruction(
+        reduce,
+        HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
+  }
+  // A Transpose feeding a reduce can simply permute the reduction dimensions
+  // field.
+  if (arg->opcode() == HloOpcode::kTranspose) {
+    auto transpose_dimensions = arg->dimensions();
+    std::vector<int64> new_reduce_dimensions;
+    for (auto dim : dimensions) {
+      new_reduce_dimensions.push_back(transpose_dimensions[dim]);
+    }
+    return ReplaceWithNewInstruction(
+        reduce, HloInstruction::CreateReduce(
+                    reduce->shape(), arg->mutable_operand(0), init_value,
+                    new_reduce_dimensions, function));
+  }
+
+  // A reshape that collapses multiple dimensions into a dimension being
+  // reduced can just reduce all of those dimensions instead of doing a
+  // collapsing reshape before a reduction.
+  if (arg->opcode() == HloOpcode::kReshape) {
+    std::vector<std::pair<int64, int64>> unmodified_dims =
+        ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
+                                                 arg->shape());
+    std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
+    std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
+    for (auto dim : dimensions) {
+      arg_dim_in_output[dim] = false;
+    }
+    for (auto dim_pair : unmodified_dims) {
+      arg_dim_unmodified[dim_pair.second] = true;
+    }
+    // The goal is to verify that all dimensions that are not removed in the
+    // reduce are unmodified by the reshape.
For example: + // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2]) + bool can_move_reshape_into_reduce = true; + for (int64 i = 0; i < arg_dim_in_output.size(); ++i) { + if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) { + can_move_reshape_into_reduce = false; + } + } + if (can_move_reshape_into_reduce) { + changed_ = true; + std::unordered_set dimensions_not_to_reduce; + for (auto dim_pair : unmodified_dims) { + if (arg_dim_in_output[dim_pair.second]) { + dimensions_not_to_reduce.insert(dim_pair.first); + } + } + std::vector new_reduce_dimensions; + for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { + if (dimensions_not_to_reduce.count(i) == 0) { + new_reduce_dimensions.push_back(i); + } + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReduce( + reduce->shape(), arg->mutable_operand(0), init_value, + new_reduce_dimensions, function)); + } + } if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape())) { + ShapeUtil::ElementsIn(arg->shape()) || + ShapeUtil::HasZeroElements(arg->shape())) { auto reshape = computation_->AddInstruction( HloInstruction::CreateReshape(reduce->shape(), arg)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateMap(reduce->shape(), {reshape, init_value}, function)); } return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleReduceWindow( + HloInstruction* reduce_window, HloInstruction* operand, + const Window& window, HloComputation* function) { + VLOG(10) << "Considering folding Pad: " << operand->ToString() + << "\ninto reduce-window: " << reduce_window->ToString(); + + // This optimization folds a pad op into reduce_window. + if (operand->opcode() != HloOpcode::kPad) { + VLOG(10) << "Not folding pad into reduce-window as there is no pad."; + return Status::OK(); + } + + // Do not fold interior padding into ReduceWindow since the backends do not + // support it. + const PaddingConfig& pad_config = operand->padding_config(); + if (HasInteriorPadding(pad_config)) { + VLOG(10) << "Not folding pad into reduce-window due to interior padding."; + return Status::OK(); + } + + // If reduce_window already has padding, the pad value of the pad op and the + // init value of reduce_window must match to allow folding the pad. + const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* reduce_init_value = reduce_window->operand(1); + if (pad_value != reduce_init_value) { + // The pad value is usually a constant, so we handle that case and do not + // try to get more fancy about proving equivalence in cases beyond that. + if (pad_value->opcode() != HloOpcode::kConstant || + reduce_init_value->opcode() != HloOpcode::kConstant || + !LiteralUtil::Equal(pad_value->literal(), + reduce_init_value->literal())) { + VLOG(10) << "Not folding pad into reduce-window due to different pad " + "values."; + return Status::OK(); + } + } + + // Carry out the folding of the pad into reduce_window. 
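+  // Illustrative example: if the pad contributes edge_padding_low=1 on a
+  // dimension whose window dimension already has padding_low=1, the folded
+  // window gets padding_low=2 and the standalone pad instruction goes away.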
+ VLOG(10) << "Folding pad into reduce-window."; + Window new_window = window; + const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + TF_RET_CHECK(pad_config.dimensions_size() == rank); + TF_RET_CHECK(window.dimensions_size() == rank); + for (int64 i = 0; i < rank; ++i) { + const auto& pad_dim = pad_config.dimensions(i); + auto& window_dim = *new_window.mutable_dimensions(i); + window_dim.set_padding_low(window_dim.padding_low() + + pad_dim.edge_padding_low()); + window_dim.set_padding_high(window_dim.padding_high() + + pad_dim.edge_padding_high()); + } + return ReplaceWithNewInstruction( + reduce_window, HloInstruction::CreateReduceWindow( + /*shape=*/reduce_window->shape(), + /*operand=*/operand->mutable_operand(0), + /*init_value=*/reduce_window->mutable_operand(1), + /*window=*/new_window, + /*reduce_computation=*/function)); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; - changed_ = true; - return computation_->ReplaceInstruction(transpose, operand); + return ReplaceInstruction(transpose, operand); } if (HloOpcode::kTranspose == operand->opcode()) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( transpose, HloInstruction::CreateTranspose( transpose->shape(), operand->mutable_operand(0), ComposePermutations(operand->dimensions(), @@ -805,7 +1276,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // bitcasts_ == true. // TODO(cwhipkey): b/31337498, make this layout insensitive. - if (!is_layout_sensitive_) return Status::OK(); + if (!is_layout_sensitive_) { + return Status::OK(); + } const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -905,9 +1378,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_rhs = add_bitcast(new_filter_shape, rhs); auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); - changed_ = true; - return computation_->ReplaceInstruction(convolution, - add_bitcast(convolution_shape, dot)); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( @@ -921,8 +1392,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, max_operand, operand, min_operand); - TF_CHECK_OK(computation_->ReplaceWithNewInstruction(root, std::move(clamp))); - changed_ = true; + TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); return true; } @@ -995,12 +1465,20 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); - bool changed = std::any_of( - module->computations().begin(), module->computations().end(), - [=](const std::unique_ptr& computation) { - return AlgebraicSimplifierVisitor::Run( - computation.get(), is_layout_sensitive_, valid_bitcast_callback_); - }); + bool changed = false; + // Make a copy of the computations because we may add computations to the + // module, invalidating iteration. 
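+  // (HandleDot may call AddEmbeddedComputation through
+  // CreateScalarBinaryComputation, which would invalidate iterators into
+  // module->computations() if we iterated over it directly.)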
+  std::vector<HloComputation*> computations;
+  for (auto& comp : module->computations()) {
+    computations.push_back(comp.get());
+  }
+  for (auto& comp : computations) {
+    if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_,
+                                        valid_bitcast_callback_,
+                                        enable_dot_simplification_)) {
+      changed = true;
+    }
+  }
   XLA_VLOG_LINES(2,
                  "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
   return changed;
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index d10d1edc1d2..f8919f0caad 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -35,12 +35,14 @@ class AlgebraicSimplifier : public HloPassInterface {
   // If is_layout_sensitive is true, then the simplifier preserves layout during
   // transformation. Otherwise, layout is ignored. If valid_bitcast_callback
-  // returns true, then the pass will replace reshapes and tranposes with
+  // returns true, then the pass will replace reshapes and transposes with
   // bitcasts.
   AlgebraicSimplifier(bool is_layout_sensitive,
-                      ValidBitcastCallback valid_bitcast_callback)
+                      ValidBitcastCallback valid_bitcast_callback,
+                      bool enable_dot_simplification = true)
       : is_layout_sensitive_(is_layout_sensitive),
-        valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+        valid_bitcast_callback_(std::move(valid_bitcast_callback)),
+        enable_dot_simplification_(enable_dot_simplification) {}
   ~AlgebraicSimplifier() override {}
   tensorflow::StringPiece name() const override { return "algsimp"; }
@@ -51,6 +53,9 @@
  private:
   bool is_layout_sensitive_;
   ValidBitcastCallback valid_bitcast_callback_;
+
+  // Enable dot simplification on platforms where it is profitable.
+  bool enable_dot_simplification_;
 };
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 8dd94e2c70c..e4368a7bb25 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -23,21 +23,25 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } + AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } @@ -55,7 +59,53 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction* bcast = + builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); @@ -77,7 +127,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( 
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); @@ -99,7 +149,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); @@ -121,7 +171,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); @@ -149,7 +199,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); @@ -157,9 +207,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_EQ(root, add); - EXPECT_EQ(root->operand(0), param1); - EXPECT_EQ(root->operand(1), param2); + EXPECT_THAT(root, op::Add(param1, param2)); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -177,19 +225,18 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kDivide); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Exp(param0), op::Exp(param1))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kExp); - EXPECT_EQ(root->operand_count(), 1); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kSubtract); - EXPECT_EQ(root->operand(0)->operand(0), param0); - EXPECT_EQ(root->operand(0)->operand(1), param1); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Subtract(param0, param1))); } // Test that ln(exp(A)) is simplified to A @@ -203,16 +250,16 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kLog); + + EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 
non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); - EXPECT_EQ(root, param0); + + EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B @@ -232,17 +279,17 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kLog); + + EXPECT_THAT(computation->root_instruction(), + op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - EXPECT_EQ(root->operand(0), param0); - EXPECT_EQ(root->operand(1), param1); + + EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -257,13 +304,17 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kConstant); + EXPECT_THAT(root, op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); } @@ -278,13 +329,17 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_THAT(root, op::Broadcast()); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -304,14 +359,16 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); - EXPECT_EQ(root, param0); + + EXPECT_EQ(computation->root_instruction(), param0); } // Test 
that pow(A, 2) is simplified to A*A. @@ -325,15 +382,16 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - EXPECT_EQ(root->operand(0), param0); - EXPECT_EQ(root->operand(1), param0); + + EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } // Test that pow(A, -1) is simplified to 1/A. @@ -347,17 +405,19 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kDivide); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_THAT(root, op::Divide(op::Constant(), param0)); EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), 1); - EXPECT_EQ(root->operand(1), param0); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -374,14 +434,17 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(op::Broadcast(op::Reshape(op)))); + HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = module->entry_computation()->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); + + EXPECT_THAT(module->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. 
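The recurring change in this test file is the move from chained opcode/operand `EXPECT_EQ` checks to single `EXPECT_THAT` assertions built from the gMock-style matchers in `hlo_matchers.h` (aliased as `op::` above). As a minimal, self-contained sketch of how such composable tree matchers can be written on plain gMock, here is a toy version over a hypothetical `Node` type; `OpMatcher`, `Add`, and `Constant` are illustrative stand-ins, not the real XLA implementation:

```cpp
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

// Toy expression node standing in for HloInstruction.
struct Node {
  std::string opcode;
  std::vector<const Node*> operands;
};

// Matches a node whose opcode equals `opcode` and whose operands match the
// given matchers, recursively -- the same shape of API as op::Add(...).
testing::Matcher<const Node*> OpMatcher(
    const std::string& opcode,
    std::vector<testing::Matcher<const Node*>> operand_matchers) {
  return testing::AllOf(
      testing::Field(&Node::opcode, opcode),
      testing::Field(&Node::operands,
                     testing::ElementsAreArray(operand_matchers)));
}

// Convenience wrappers, analogous to op::Add / op::Constant.
testing::Matcher<const Node*> Add(testing::Matcher<const Node*> lhs,
                                  testing::Matcher<const Node*> rhs) {
  return OpMatcher("add", {lhs, rhs});
}
testing::Matcher<const Node*> Constant() { return OpMatcher("constant", {}); }

TEST(ToyMatcherTest, MatchesTree) {
  Node c0{"constant", {}};
  Node c1{"constant", {}};
  Node add{"add", {&c0, &c1}};
  // One assertion checks the whole tree, like EXPECT_THAT(root, op::Add(...)).
  EXPECT_THAT(&add, Add(Constant(), Constant()));
}
```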
 
 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
@@ -392,85 +455,16 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
   builder.AddInstruction(
       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
-}
-
-TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* input = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
-
-  auto module = MakeUnique<HloModule>(TestName());
-  auto computation = module->AddEntryComputation(builder.Build());
-
-  EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
-
-  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
-                                 non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
-  EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
-  EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
-                computation->root_instruction()->literal()),
-            42);
-}
-
-TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* input = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
-
-  auto module = MakeUnique<HloModule>(TestName());
-  auto computation = module->AddEntryComputation(builder.Build());
-
-  EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
-
-  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
-                                 non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
-  EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
-  EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
-                computation->root_instruction()->literal()),
-            42.0f);
-}
-
-TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) {
-  HloComputation::Builder builder(TestName());
-  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
-
-  auto module = MakeUnique<HloModule>(TestName());
-  auto computation = module->AddEntryComputation(builder.Build());
-
-  EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
-
-  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
-                                 non_bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
-  EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
-  EXPECT_EQ(
-      LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
-      42);
-  EXPECT_EQ(
-      LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
-      19);
+  EXPECT_THAT(computation->root_instruction(), input);
 }
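The surviving assertions exercise only the identity case: a convert whose target element type equals the operand's own type folds away to the operand. A minimal sketch of that rewrite rule over a toy IR (the `Value` type here is hypothetical, not XLA's):

```cpp
#include <cassert>

enum class Type { F32, S64 };

// A toy value node with just enough structure to express the rule:
// convert(x) is a no-op when x already has the target element type.
struct Value {
  Type type;
  Value* operand = nullptr;  // set for convert nodes
  bool is_convert = false;
};

// Returns the replacement for `v`, or `v` itself if no rule applies,
// mirroring the shape of the simplification the test above exercises.
Value* SimplifyConvert(Value* v) {
  if (v->is_convert && v->operand->type == v->type) {
    return v->operand;  // convert<T>(x : T) ==> x
  }
  return v;
}

int main() {
  Value x{Type::F32};
  Value same{Type::F32, &x, /*is_convert=*/true};
  assert(SimplifyConvert(&same) == &x);  // identity convert folds away
  Value widen{Type::S64, &x, /*is_convert=*/true};
  assert(SimplifyConvert(&widen) == &widen);  // real type change is kept
  return 0;
}
```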
 
 // Test that copies are removed.
@@ -479,19 +473,125 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
   HloComputation::Builder builder(TestName());
   HloInstruction* param0 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, r0f32, "param0"));
-  HloInstruction* copy = builder.AddInstruction(
+  builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(copy, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(param0, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), param0);
+}
+
+// Test that unary concatenates are removed.
+TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
+  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "param0"));
+  builder.AddInstruction(
+      HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(), param0);
+}
+
+// Test that empty operands of concatenates are removed.
+TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
+  const int kParamLength = 100;
+  Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r1f32, "param1"));
+  HloInstruction* empty_literal = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
+  HloInstruction* empty_slice =
+      builder.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
+  Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(
+      computation->root_instruction(),
+      op::Concatenate(empty_literal, param0, param0, empty_slice, param1));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Concatenate(param0, param0, param1));
+}
+
+// Test a concatenate with only empty operands is removed.
+TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
+  const int kParamLength = 100;
+  Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "param0"));
+  HloInstruction* empty_literal = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
+  HloInstruction* empty_slice =
+      builder.AddInstruction(HloInstruction::CreateSlice(
+          ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
+  Shape result_shape = ShapeUtil::MakeShape(F32, {0});
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      result_shape, {empty_literal, empty_slice}, 0));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Concatenate(empty_literal, empty_slice));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_EQ(computation->root_instruction(), empty_literal);
+}
+
+// Test that concat with a scalar broadcast becomes a pad.
+TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
+  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32, "param1"));
+  HloInstruction* broadcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r1f32, param1, {}));
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      param0->shape(), {broadcast, param0}, 0));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
 }
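The concatenate tests above boil down to one rule plus a special case: operands with zero extent along the concatenation dimension contribute nothing and can be dropped, and a concatenate left with a single operand is just that operand. A small self-contained sketch of the operand filtering, with a hypothetical `Operand` type:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Toy operand: a name plus the extent along the concatenation dimension.
struct Operand {
  std::string name;
  int64_t extent;
};

// Drops operands that are empty along the concat dimension, mirroring
// RemoveEmptyConcatenateOperands above.
std::vector<Operand> SimplifyConcatOperands(const std::vector<Operand>& ops) {
  std::vector<Operand> kept;
  for (const Operand& o : ops) {
    if (o.extent != 0) kept.push_back(o);  // zero-extent operands vanish
  }
  return kept;
}

int main() {
  std::vector<Operand> ops = {{"empty_literal", 0},
                              {"param0", 100},
                              {"param0", 100},
                              {"empty_slice", 0},
                              {"param1", 100}};
  std::vector<Operand> kept = SimplifyConcatOperands(ops);
  assert(kept.size() == 3);  // concat(param0, param0, param1) remains
  // A single survivor would mean the concatenate disappears entirely,
  // as in RemoveUnaryConcatenate above.
  return 0;
}
```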
 
 // Test that a simplification which changes layouts is not performed if layout
@@ -504,21 +604,21 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
   HloInstruction* copy = builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   // Set to different layouts.
   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
 
-  EXPECT_EQ(copy, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   // Copy has not been removed.
-  EXPECT_EQ(copy, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 }
 
 // Test that a simplification which preserves layouts is performed if layout
@@ -531,21 +631,21 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
   HloInstruction* copy = builder.AddInstruction(
       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   // Set to same layouts.
   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
 
-  EXPECT_EQ(copy, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Copy has been removed.
-  EXPECT_EQ(param0, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), param0);
 }
 
 // Test that a reshape which could be replaced with a bitcast is not if
@@ -563,17 +663,17 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
   *reshape->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(reshape, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
 
   // Reshape is not replaced with a bitcast.
-  EXPECT_EQ(reshape, computation->root_instruction());
+  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
 }
 
 // Test transforming reshapes to bitcasts under various conditions.
@@ -609,25 +709,48 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
   builder.AddInstruction(HloInstruction::CreateTuple(
       {transformable_reshape, dimensions_wrong_reshape,
       layout_wrong_reshape}));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(transformable_reshape, computation->root_instruction()->operand(0));
-  EXPECT_EQ(dimensions_wrong_reshape,
-            computation->root_instruction()->operand(1));
-  EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+  EXPECT_THAT(computation->root_instruction(),
+              op::Tuple(transformable_reshape, dimensions_wrong_reshape,
+                        layout_wrong_reshape));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
-  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  simplifier.Run(module.get()).ValueOrDie();
 
   // Verify that only the first reshape is replaced.
-  EXPECT_NE(transformable_reshape, computation->root_instruction()->operand(0));
-  EXPECT_EQ(HloOpcode::kBitcast,
-            computation->root_instruction()->operand(0)->opcode());
-  EXPECT_EQ(dimensions_wrong_reshape,
-            computation->root_instruction()->operand(1));
-  EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+  EXPECT_THAT(
+      computation->root_instruction(),
+      op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
+}
+
+TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param"));
+  HloInstruction* movable_reshape =
+      builder.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
+  HloInstruction* zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
+                                   HloOpcode::kMaximum, movable_reshape, zero));
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Maximum(op::Reshape(param), zero));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 bitcasting_callback());
+
+  simplifier.Run(module.get()).ValueOrDie();
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Maximum(param, zero)));
 }
 
 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
@@ -644,16 +767,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
   *transpose->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({0, 1, 2, 3});
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Verify that the reshape is replaced.
-  EXPECT_EQ(2, computation->instruction_count());
-  EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
 }
 
 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
@@ -670,16 +794,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
   *transpose->mutable_shape()->mutable_layout() =
       LayoutUtil::MakeLayout({3, 1, 2, 0});
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                  bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
   // Verify that the reshape is replaced.
-  EXPECT_EQ(2, computation->instruction_count());
-  EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
 }
 
 TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
@@ -692,23 +817,47 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
       builder.AddInstruction(HloInstruction::CreateReshape(
           ShapeUtil::MakeShape(F32, {2, 1, 2}), param0));
 
-  HloInstruction* reshape2 =
-      builder.AddInstruction(HloInstruction::CreateReshape(
-          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
+  builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(reshape2, computation->root_instruction());
-  EXPECT_EQ(reshape1, computation->root_instruction()->operand(0));
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Reshape(param0)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode());
-  EXPECT_EQ(HloOpcode::kParameter,
-            computation->root_instruction()->operand(0)->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
+}
+
+TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}),
+          "param0"));
+
+  HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
+      HloOpcode::kCopy, param0));
+
+  builder.AddInstruction(HloInstruction::CreateUnary(
+      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
+      HloOpcode::kCopy, copy1));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
 }
 
 TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
@@ -721,25 +870,21 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
       builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));
 
-  HloInstruction* transpose2 =
-      builder.AddInstruction(HloInstruction::CreateTranspose(
-          ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
+  builder.AddInstruction(HloInstruction::CreateTranspose(
+      ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(transpose2, computation->root_instruction());
-  EXPECT_EQ(transpose1, computation->root_instruction()->operand(0));
+  EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kTranspose, computation->root_instruction()->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Transpose(param0));
   EXPECT_EQ(std::vector<int64>({2, 1, 0}),
             computation->root_instruction()->dimensions());
-  EXPECT_EQ(HloOpcode::kParameter,
-            computation->root_instruction()->operand(0)->opcode());
 }
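The merged dimensions `{2, 1, 0}` asserted in `TransposesMerged` follow from composing the two permutations, assuming XLA's convention that `dimensions(i)` names the operand dimension feeding output dimension i. A short check of that arithmetic:

```cpp
#include <cassert>
#include <vector>

// Composes two transpose permutations: the merged transpose maps output
// dimension i to input dimension first[second[i]].
std::vector<int> ComposeTransposes(const std::vector<int>& first,
                                   const std::vector<int>& second) {
  std::vector<int> merged(second.size());
  for (size_t i = 0; i < second.size(); ++i) {
    merged[i] = first[second[i]];
  }
  return merged;
}

int main() {
  // The dimensions from TransposesMerged above: {1, 2, 0} then {1, 0, 2}.
  std::vector<int> merged = ComposeTransposes({1, 2, 0}, {1, 0, 2});
  assert((merged == std::vector<int>{2, 1, 0}));  // matches the EXPECT_EQ
  return 0;
}
```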
 
 // Test merging reshape and broadcast.
@@ -752,16 +897,17 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
   builder.AddInstruction(HloInstruction::CreateBroadcast(
       ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(),
+              op::Broadcast(op::Reshape(param0)));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
-  EXPECT_EQ(HloOpcode::kParameter,
-            computation->root_instruction()->operand(0)->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
 }
 
 // Test merging broadcast and reshape.
@@ -774,16 +920,17 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param0)));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
-  EXPECT_EQ(HloOpcode::kParameter,
-            computation->root_instruction()->operand(0)->opcode());
+  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
 }
 
 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
@@ -795,12 +942,18 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
   builder.AddInstruction(
       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
 
-  auto module = MakeUnique<HloModule>(TestName());
-  module->AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
 }
 
 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
@@ -812,15 +965,19 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
 
-  EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
-  EXPECT_MATCH(computation->root_instruction()->dimensions(),
-               testing::VectorMatcher<int64>({3}));
+
+  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
+  EXPECT_THAT(computation->root_instruction()->dimensions(),
+              ::testing::ElementsAre(3));
 }
 
 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
@@ -832,18 +989,21 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-  EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+
+  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
   const std::vector<int64> broadcast_dims =
       computation->root_instruction()->dimensions();
   EXPECT_EQ(1, broadcast_dims.size());
-  EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 ||
-              broadcast_dims[3] == 3);
+  EXPECT_THAT(broadcast_dims[0], ::testing::AnyOf(1, 2, 3));
 }
 
 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
@@ -855,12 +1015,18 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
   builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
 
-  auto module = MakeUnique<HloModule>(TestName());
-  module->AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
 
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Reshape(op::Broadcast(param)));
 }
 
 TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
@@ -871,7 +1037,7 @@
   HloInstruction* zero = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
   PaddingConfig no_padding;
-  for (auto i = 0; i < 2; ++i) {
+  for (int i = 0; i < 2; ++i) {
     auto dimension = no_padding.add_dimensions();
     dimension->set_edge_padding_low(0);
     dimension->set_edge_padding_high(0);
@@ -883,10 +1049,13 @@
   HloModule module(TestName());
   HloComputation* computation = module.AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  EXPECT_EQ(1, computation->instruction_count());
+
+  EXPECT_THAT(computation->root_instruction(), param);
 }
 
 TEST_F(AlgebraicSimplifierTest, NegativePadding) {
@@ -901,7 +1070,7 @@
   PaddingConfig padding;
   int64 low_padding[2] = {-1, -2};
   int64 high_padding[2] = {2, -3};
-  for (auto i = 0; i < 2; ++i) {
+  for (int i = 0; i < 2; ++i) {
     auto dimension = padding.add_dimensions();
     dimension->set_edge_padding_low(low_padding[i]);
     dimension->set_edge_padding_high(high_padding[i]);
@@ -926,18 +1095,14 @@
     return false;
   };
 
-  EXPECT_EQ(3, computation->instruction_count());
-  EXPECT_EQ(computation->root_instruction(), pad);
+  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
   EXPECT_TRUE(has_negative_padding(pad));
 
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
 
-  EXPECT_EQ(4, computation->instruction_count());
-  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kSlice);
-  const HloInstruction* root_operand =
-      computation->root_instruction()->operand(0);
-  EXPECT_EQ(root_operand->opcode(), HloOpcode::kPad);
-  EXPECT_FALSE(has_negative_padding(root_operand));
+  EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
+  EXPECT_FALSE(
+      has_negative_padding(computation->root_instruction()->operand(0)));
 }
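`NegativePadding` relies on decomposing a pad with negative edge padding into a non-negative pad followed by a slice that trims the negative amounts. A sketch of the per-dimension arithmetic, using a hypothetical dimension size of 10 (the test's actual operand shape lies outside this hunk):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Splits one padding dimension into a non-negative pad plus a slice, the
// decomposition NegativePadding above checks for: negative edge padding of
// k elements becomes a slice trimming k elements from that edge.
struct PadAndSlice {
  int64_t pad_low, pad_high;          // non-negative padding to keep
  int64_t slice_start, slice_limit;   // slice into the padded dimension
};

PadAndSlice Decompose(int64_t dim_size, int64_t low, int64_t high) {
  PadAndSlice result;
  result.pad_low = std::max<int64_t>(low, 0);
  result.pad_high = std::max<int64_t>(high, 0);
  const int64_t padded = dim_size + result.pad_low + result.pad_high;
  result.slice_start = low < 0 ? -low : 0;
  result.slice_limit = padded - (high < 0 ? -high : 0);
  return result;
}

int main() {
  // Dimension 0 of the test: low padding -1, high padding +2.
  PadAndSlice d0 = Decompose(10, -1, 2);
  assert(d0.pad_low == 0 && d0.pad_high == 2);
  assert(d0.slice_start == 1 && d0.slice_limit == 12);  // 10 + 2 padded
  // Dimension 1: low -2, high -3; a pure shrink, so the pad disappears.
  PadAndSlice d1 = Decompose(10, -2, -3);
  assert(d1.pad_low == 0 && d1.pad_high == 0);
  assert(d1.slice_start == 2 && d1.slice_limit == 7);
  return 0;
}
```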
 
 TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
@@ -951,10 +1116,13 @@
   HloModule module(TestName());
   HloComputation* computation = module.AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  EXPECT_EQ(1, computation->instruction_count());
+
+  EXPECT_THAT(computation->root_instruction(), param);
 }
 
 TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
@@ -966,15 +1134,18 @@
       0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
   builder.AddInstruction(HloInstruction::CreateSlice(
       ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
-      /*limit_indices=*/{dim0, dim1}));
+      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
 
   HloModule module(TestName());
   HloComputation* computation = module.AddEntryComputation(builder.Build());
 
+  EXPECT_THAT(computation->root_instruction(), op::Slice(param));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  EXPECT_EQ(1, computation->instruction_count());
+
+  EXPECT_THAT(computation->root_instruction(), param);
 }
 
 TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
@@ -1210,21 +1381,21 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
   HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
       r0f32, HloOpcode::kMinimum, param0, min_value));
-  HloInstruction* max = builder.AddInstruction(
+  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
 
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root, max);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Maximum(op::Minimum(param0, min_value), max_value));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  root = computation->root_instruction();
-  ASSERT_EQ(root->opcode(), HloOpcode::kClamp);
-  EXPECT_EQ(root->operand(0), max_value);
-  EXPECT_EQ(root->operand(1), param0);
-  EXPECT_EQ(root->operand(2), min_value);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Clamp(max_value, param0, min_value));
 }
 
 // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
@@ -1240,21 +1411,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
       r0f32, HloOpcode::kMaximum, param0, max_value));
-  HloInstruction* min = builder.AddInstruction(
+  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
 
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root, min);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Maximum(param0, max_value), min_value));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
-  EXPECT_EQ(root->operand(0), max_value);
-  EXPECT_EQ(root->operand(1), param0);
-  EXPECT_EQ(root->operand(2), min_value);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Clamp(max_value, param0, min_value));
 }
 
 // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
@@ -1271,21 +1442,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kMaximum, param0, max_value));
-  HloInstruction* min = builder.AddInstruction(
+  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
 
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root, min);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Maximum(param0, max_value), min_value));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
-  EXPECT_EQ(root->operand(0), max_value);
-  EXPECT_EQ(root->operand(1), param0);
-  EXPECT_EQ(root->operand(2), min_value);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Clamp(max_value, param0, min_value));
 }
 
 // Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
@@ -1301,17 +1472,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
      HloInstruction::CreateParameter(2, r0f32, "param2"));
   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kMaximum, param0, max_value));
-  HloInstruction* min = builder.AddInstruction(
+  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
 
-  HloInstruction* root = computation->root_instruction();
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Maximum(param0, max_value), min_value));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_EQ(root, min);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Maximum(param0, max_value), min_value));
 }
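The clamp rewrites above rest on a scalar identity: with the constant bounds ordered `lo <= hi`, both `max(min(a, hi), lo)` and `min(max(a, lo), hi)` pin `a` into `[lo, hi]`, which is exactly `clamp(lo, a, hi)`. A quick numeric check with the tests' 0.0f/1.0f constants:

```cpp
#include <algorithm>
#include <cassert>

// clamp(lo, a, hi), the canonical form the simplifier rewrites to.
float Clamp(float lo, float a, float hi) {
  return std::min(std::max(a, lo), hi);
}

int main() {
  const float lo = 0.0f, hi = 1.0f;  // the constants used in the tests above
  for (float a : {-2.0f, 0.5f, 3.0f}) {
    const float via_max_min = std::max(std::min(a, hi), lo);
    const float via_min_max = std::min(std::max(a, lo), hi);
    // Both orderings agree with clamp whenever lo <= hi; the exact float
    // comparisons are safe because min/max introduce no rounding.
    assert(via_max_min == Clamp(lo, a, hi));
    assert(via_min_max == Clamp(lo, a, hi));
  }
  return 0;
}
```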
 
 // Test that min(f(max(A, constant1)), constant2) is not transformed to
@@ -1329,18 +1504,23 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
      r0f32, HloOpcode::kMaximum, param0, max_value));
   HloInstruction* fmax = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
-  HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
+  builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kMinimum, fmax, min_value));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
 
-  HloInstruction* root = computation->root_instruction();
-  EXPECT_EQ(root, min);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
+                          min_value));
+
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
-  root = computation->root_instruction();
-  EXPECT_EQ(root, min);
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
+                          min_value));
 }
 
 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
@@ -1359,7 +1539,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
   Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
   HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
-      slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
+      slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
 
   HloModule module(TestName());
   auto computation = module.AddEntryComputation(builder.Build());
@@ -1377,8 +1557,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
   ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
 
   root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
-  EXPECT_EQ(scalar_param, root->operand(0));
+  EXPECT_THAT(root, op::Broadcast(scalar_param));
   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
 }
 
@@ -1415,10 +1594,143 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
 
   root = computation->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
-  EXPECT_EQ(forty_two, root->operand(0));
+  EXPECT_THAT(root, op::Broadcast(forty_two));
   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
 }
 
+// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
+TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
+  HloModule module(TestName());
+  HloComputation::Builder builder(TestName());
+
+  // Create operand to the pad.
+  HloInstruction* operand =
+      builder.AddInstruction(HloInstruction::CreateParameter(
+          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
+
+  // Create the pad.
+  PaddingConfig padding = MakeNoPaddingConfig(4);
+  padding.mutable_dimensions(1)->set_edge_padding_low(1);
+  padding.mutable_dimensions(3)->set_edge_padding_high(2);
+
+  HloInstruction* pad_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
+  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+      ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
+
+  // Create add computation.
+  HloComputation* add_computation = nullptr;
+  {
+    HloComputation::Builder builder(TestName() + ".add");
+    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+    HloInstruction* p0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+    HloInstruction* p1 = builder.AddInstruction(
+        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+    add_computation = module.AddEmbeddedComputation(builder.Build());
+  }
+
+  // Create the reduce-window.
+  Window window;
+  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+    auto* dim = window.add_dimensions();
+    dim->set_size(1);
+    dim->set_padding_low(10);
+    dim->set_padding_high(100);
+    dim->set_window_dilation(1);
+    dim->set_base_dilation(1);
+  }
+  const Shape reduce_window_shape =
+      ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
+  HloInstruction* reduce_init_value = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
+  HloInstruction* reduce_window =
+      builder.AddInstruction(HloInstruction::CreateReduceWindow(
+          reduce_window_shape, pad, reduce_init_value, window,
+          add_computation));
+
+  // Build the computation and run the simplifier.
+  auto computation = module.AddEntryComputation(builder.Build());
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(root, reduce_window);
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+  // Running simplification again should not result in any further changes.
+  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+
+  // Verify the result
+  root = computation->root_instruction();
+  EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant()));
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
+      << ShapeUtil::HumanString(root->shape()) << " vs "
+      << ShapeUtil::HumanString(reduce_window_shape);
+  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
+  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
+  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
+  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
+}
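The expectations at the end of `FoldPadIntoReduceWindow` are plain bookkeeping: when the kPad folds into the kReduceWindow, each dimension's edge padding is added onto the window's own padding, so dimension 1 gets low padding 10 + 1 = 11 and dimension 3 gets high padding 100 + 2 = 102. Note the fold is only safe here because the pad value and the reduce init value agree (both 5.0f). The arithmetic, spelled out:

```cpp
#include <cassert>
#include <cstdint>

// Per-dimension padding, standing in for one WindowDimension.
struct DimPadding {
  int64_t low, high;
};

// Folding a pad into a reduce-window sums the edge padding into the
// window padding, dimension by dimension.
DimPadding FoldPad(DimPadding window, DimPadding pad_edges) {
  return {window.low + pad_edges.low, window.high + pad_edges.high};
}

int main() {
  const DimPadding window = {10, 100};  // every window dim in the test
  // Dimension 1 has edge_padding_low = 1; dimension 3 edge_padding_high = 2.
  assert(FoldPad(window, {1, 0}).low == 11);    // checked as padding_low(1)
  assert(FoldPad(window, {0, 2}).high == 102);  // checked as padding_high(3)
  assert(FoldPad(window, {0, 0}).low == 10);    // untouched dimensions
  return 0;
}
```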
+
+TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
+  HloComputation::Builder builder(TestName());
+  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
+  HloInstruction* a =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+  builder.AddInstruction(
+      HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
+
+  HloModule module(TestName());
+  auto computation = module.AddEntryComputation(builder.Build());
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_EQ(a, root);
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
+  // Dots add computations to the parent module. Test that, when the
+  // HloModule's computations are updated, then iterator invalidation doesn't
+  // occur when running on subsequent computations.
+  Shape r1f32 = ShapeUtil::MakeShape(F32, {1});
+  HloComputation::Builder builder(TestName() + ".Dot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y));
+  std::unique_ptr<HloComputation> dot_computation(builder.Build());
+
+  HloComputation::Builder call_builder(TestName() + ".Call");
+  HloInstruction* zero = call_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
+  HloInstruction* one = call_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
+  call_builder.AddInstruction(
+      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
+
+  auto module = CreateNewModule();
+  module->AddEmbeddedComputation(std::move(dot_computation));
+  module->AddEntryComputation(call_builder.Build());
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index a123213401d..ad2fee2d39a 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal(
     auto& allocation = FindOrDie(handle_to_allocation_, handle);
     int ref_count = allocation->ref_count();
     CHECK_GT(ref_count, 0);
-    VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1;
-    allocation->increment_ref_count();
+    VLOG(2) << "ref_count: " << ref_count << " -> " <<
+        (ref_count + initial_ref_count);
+    allocation->increment_ref_count(initial_ref_count);
   } else {
     handle = next_handle_++;
     VLOG(2) << "ref_count: " << initial_ref_count;
@@ -136,7 +137,7 @@ tensorflow::Status AllocationTracker::DeallocateShape(
       TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size())
           << "tuple has unexpected number of elements: " << elements.size()
           << " != " << ShapeUtil::TupleElementCount(shape);
-      for (int i = 0; i < elements.size(); ++i) {
+      for (size_t i = 0; i < elements.size(); ++i) {
         VLOG(2) << "recursing onto the tuple elements";
         TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal,
                                            &elements[i], shape.tuple_shapes(i),
@@ -170,6 +171,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
                           executor, allocation->device_memory(),
                           allocation->shape()));
   std::vector<GlobalDataHandle> element_handles;
+  element_handles.reserve(element_bases.size());
   for (int i = 0; i < element_bases.size(); ++i) {
     element_handles.push_back(RegisterInternal(
         allocation->backend(), allocation->device_ordinal(), element_bases[i],
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index e0076800162..ebbf35b6fe8 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -63,10 +63,10 @@ class Allocation {
     CHECK_GE(ref_count_, 0);
     return ref_count_;
   }
-  void increment_ref_count() {
+  void increment_ref_count(int inc) {
     CHECK_GT(ref_count_, 0);
-    CHECK_LT(ref_count_, INT_MAX);
-    ++ref_count_;
+    CHECK_LE(ref_count_, INT_MAX - inc);
+    ref_count_ += inc;
   }
   void decrement_ref_count() {
     CHECK_GT(ref_count_, 0);
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 7452a7b6965..66d54ad3802 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -41,13 +41,39 @@ namespace se = ::perftools::gputools;
 
 namespace xla {
 
+BackendOptions& BackendOptions::set_platform(
+    perftools::gputools::Platform* platform) {
+  platform_ = platform;
+  return *this;
+}
+
+perftools::gputools::Platform* BackendOptions::platform() const {
+  return platform_;
+}
+
+BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
+  number_of_replicas_ = number_of_replicas;
+  return *this;
+}
+
+int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
+
+BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
+    int num_threads) {
+  intra_op_parallelism_threads_ = num_threads;
+  return *this;
+}
+
+int BackendOptions::intra_op_parallelism_threads() const {
+  return intra_op_parallelism_threads_;
+}
+
 // Define this in .cc file to avoid having to include eigen or forward declare
 // these types in the header.
 struct Backend::EigenThreadPoolWrapper {
-  explicit EigenThreadPoolWrapper()
-      : pool(new tensorflow::thread::ThreadPool(
-            tensorflow::Env::Default(), "XLAEigen",
-            tensorflow::port::NumSchedulableCPUs())),
+  explicit EigenThreadPoolWrapper(const int num_threads)
+      : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(),
+                                                "XLAEigen", num_threads)),
         wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
         device(new Eigen::ThreadPoolDevice(wrapper.get(),
                                            wrapper->NumThreads())) {}
@@ -58,20 +84,21 @@ struct Backend::EigenThreadPoolWrapper {
 };
 
 /* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
-    perftools::gputools::Platform* platform, int64 replica_count) {
+    const BackendOptions& options) {
+  int64 replica_count = options.number_of_replicas();
   if (replica_count == -1) {
     legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
     replica_count = flags->xla_replicas;
   }
+  perftools::gputools::Platform* platform = options.platform();
   TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
   TF_ASSIGN_OR_RETURN(auto stream_executors,
                       PlatformUtil::GetStreamExecutors(platform));
   TF_ASSIGN_OR_RETURN(auto transfer_manager,
                       TransferManager::GetForPlatform(platform));
-  std::unique_ptr<Backend> backend(new Backend(
-      replica_count, platform, compiler, stream_executors, transfer_manager));
-  TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool,
-                                          backend->default_stream_executor()));
+  std::unique_ptr<Backend> backend(
+      new Backend(replica_count, platform, compiler, stream_executors,
+                  transfer_manager, options.intra_op_parallelism_threads()));
   return std::move(backend);
 }
 
@@ -79,51 +106,36 @@ struct Backend::EigenThreadPoolWrapper {
 Backend::CreateDefaultBackend() {
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
                       PlatformUtil::GetDefaultPlatform());
-  return CreateBackend(platform);
+  BackendOptions backend_options;
+  backend_options.set_platform(platform);
+  return CreateBackend(backend_options);
 }
 
-tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) {
-  std::vector<std::unique_ptr<se::Stream>> primed;
-  for (int i = 0; i < n; ++i) {
-    TF_ASSIGN_OR_RETURN(auto stream, AcquireStream(executor));
-    primed.emplace_back(std::move(stream));
-  }
-  for (int i = 0; i < n; ++i) {
-    ReleaseStream(std::move(primed.back()));
-    primed.pop_back();
-  }
-  return tensorflow::Status::OK();
+StatusOr<Backend::StreamPtr> Backend::BorrowStream(int device_ordinal) {
+  TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal));
+  return BorrowStream(exec);
 }
 
-StatusOr<std::unique_ptr<perftools::gputools::Stream>> Backend::AcquireStream(
-    perftools::gputools::StreamExecutor* executor) {
-  tensorflow::mutex_lock lock(mutex_);
-  auto& cached_streams = cached_streams_[executor];
-  if (!cached_streams.empty()) {
-    auto result = std::move(cached_streams.back());
-    cached_streams.pop_back();
-    return std::move(result);
+StatusOr<Backend::StreamPtr> Backend::BorrowStream(
+    se::StreamExecutor* executor) {
+  tensorflow::mutex_lock l(mu_);
+  if (0 == stream_pools_.count(executor)) {
+    stream_pools_.emplace(std::piecewise_construct,
+                          std::forward_as_tuple(executor),
+                          std::forward_as_tuple([executor]() {
+                            auto stream = MakeUnique<se::Stream>(executor);
+                            stream->Init();
+                            return stream;
+                          }));
   }
-
-  auto stream = MakeUnique<se::Stream>(executor);
-  if (!stream->Init().ok()) {
-    return InternalError("failed to initialize stream");
-  }
-  return std::move(stream);
-}
-
-void Backend::ReleaseStream(
-    std::unique_ptr<perftools::gputools::Stream> stream) {
-  tensorflow::mutex_lock lock(mutex_);
-  auto& streams = cached_streams_[stream->parent()];
-  streams.emplace_back(std::move(stream));
+  return stream_pools_.at(executor).Allocate();
 }
 
 Backend::Backend(
     int64 replica_count, perftools::gputools::Platform* platform,
     Compiler* compiler,
     tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
         stream_executors,
-    TransferManager* transfer_manager)
+    TransferManager* transfer_manager, int intra_op_parallelism_threads)
     : platform_(platform),
       compiler_(compiler),
       transfer_manager_(transfer_manager),
@@ -153,7 +165,11 @@ Backend::Backend(
     inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
         tensorflow::Env::Default(), "xla_inter_op",
         tensorflow::port::NumSchedulableCPUs()));
-    intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper());
+    const int num_threads = intra_op_parallelism_threads > 0
+                                ? intra_op_parallelism_threads
+                                : tensorflow::port::NumSchedulableCPUs();
+    intra_op_thread_pool_wrapper_.reset(
+        new EigenThreadPoolWrapper(num_threads));
   }
 }
 
@@ -199,10 +215,19 @@ tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
 
 const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
     const {
-  if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr;
+  if (intra_op_thread_pool_wrapper_ == nullptr) {
+    return nullptr;
+  }
   return intra_op_thread_pool_wrapper_->device.get();
 }
 
+tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const {
+  if (intra_op_thread_pool_wrapper_ == nullptr) {
+    return nullptr;
+  }
+  return intra_op_thread_pool_wrapper_->pool.get();
+}
+
 StatusOr<perftools::gputools::StreamExecutor*> Backend::stream_executor(
     int device_ordinal) const {
   if (device_ordinal < 0 ||
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index db482c09ae2..e0b15dc43f2 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -23,6 +23,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/pool.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -33,29 +34,50 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 
 namespace Eigen {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
 }
 
 namespace xla {
 
+// Options to configure the backend when it is created.
+class BackendOptions {
+ public:
+  // Set the platform backing the backend, or nullptr for the default platform.
+  BackendOptions& set_platform(perftools::gputools::Platform* platform);
+  perftools::gputools::Platform* platform() const;
+
+  // Set the number of replicas to use when compiling replicated
+  // programs. The default is -1 meaning that the value is read from
+  // the xla_replicas flag.
+  BackendOptions& set_number_of_replicas(int number_of_replicas);
+  int number_of_replicas() const;
+
+  // Sets the thread pool size for parallel execution of an individual
+  // operator. The default value of -1 will result in initializing the thread
+  // pool with the number of threads equal to the number of cores in the
+  // system.
+  BackendOptions& set_intra_op_parallelism_threads(int num_threads);
+  int intra_op_parallelism_threads() const;
+
+ private:
+  perftools::gputools::Platform* platform_ = nullptr;
+  int number_of_replicas_ = -1;
+  int intra_op_parallelism_threads_ = -1;
+};
+
 // Class which encapsulates an XLA backend. It includes everything necessary
 // to compile and execute computations on a particular platform.
 //
 // It also offers a pooling API for creation/use of initialized streams:
 //
-//    std::unique_ptr<se::Stream> stream =
-//        backend->AcquireStream().ConsumeValueOrDie();
-//    // ... use stream ...
-//    backend->ReleaseStream(std::move(stream));
+//    StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie();
class Backend {
  public:
-  // The number of streams we create for the pool at initialization time.
-  static constexpr int kInitialStreamsToPool = 8;
+  using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
 
   // Creates a new backend for the given platform with the given number of
-  // replicas. A value of -1 means to use the flag value.
+  // replicas.
   static StatusOr<std::unique_ptr<Backend>> CreateBackend(
-      perftools::gputools::Platform* platform, int64 replica_count = -1);
+      const BackendOptions& options);
 
   // Creates a backend for the default platform. The default platform is
   // defined in PlatformUtil.
@@ -108,22 +130,19 @@ class Backend {
     return stream_executors_[0];
   }
 
-  // Primes the internal pool of streams for AcquireStream/ReleaseStream with n
-  // initialized stream instances.
-  tensorflow::Status PoolStreams(int n,
-                                 perftools::gputools::StreamExecutor* executor);
-
-  // Acquires a stream for use by the caller, either by grabbing it from an
+  // Borrows a stream for use by the caller, either by grabbing it from an
   // internal pool, or by constructing/initializating it, and returns the
   // result to the caller.
-  //
-  // TODO(b/32989582): Return std::unique_ptr with custom deleter.
-  StatusOr<std::unique_ptr<perftools::gputools::Stream>> AcquireStream(
+  StatusOr<StreamPtr> BorrowStream(int device_ordinal);
+  StatusOr<StreamPtr> BorrowStream(
      perftools::gputools::StreamExecutor* executor);
 
-  // Releases a stream from the caller to the internal pool, for use with the
-  // paired AcquireStream above.
-  void ReleaseStream(std::unique_ptr<perftools::gputools::Stream> stream);
+  // Returns a function to borrow a stream, as `BorrowStream` above does.
+  // Purely for convenience, the caller could rather make this anonymous
+  // function itself.
+  std::function<StatusOr<StreamPtr>(int)> StreamBorrower() {
+    return [this](int device_ordinal) { return BorrowStream(device_ordinal); };
+  }
 
   // Returns whether the given device ordinal of the backend is supported.
   bool device_ordinal_supported(int device_ordinal) const {
@@ -148,6 +167,7 @@ class Backend {
   // For the host platform, returns the configured eigen threadpool device to
   // be used for scheduling work. For other platforms, returns NULL.
   const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
+  tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const;
 
   // Resets the devices associated with this backend.
   Status ResetDevices();
@@ -158,7 +178,7 @@ class Backend {
           Compiler* compiler,
           tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
               stream_executors,
-          TransferManager* transfer_manager);
+          TransferManager* transfer_manager, int intra_op_parallelism_threads);
   Backend(const Backend&) = delete;
   Backend& operator=(const Backend&) = delete;
@@ -170,14 +190,12 @@ class Backend {
   // Vector of stream executors. stream_executors_[0] is the default executor.
   std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
 
-  // Guards the mutable state in the backend object.
-  tensorflow::mutex mutex_;
+  tensorflow::mutex mu_;
 
-  // Mapping from stream executor to cached streams, used by
-  // AcquireStream/ReleaseStream above.
+  // Mapping from stream executor to stream pools, used by `BorrowStream` above.
   std::map<perftools::gputools::StreamExecutor*,
-           std::vector<std::unique_ptr<perftools::gputools::Stream>>>
-      cached_streams_ GUARDED_BY(mutex_);
+           Pool<perftools::gputools::Stream>>
+      stream_pools_ GUARDED_BY(mu_);
 
   // The default memory allocator to use.
   std::unique_ptr<DeviceMemoryAllocator> memory_allocator_;
namespace xla { -void BufferAllocation::AddAssignment(const LogicalBuffer& buffer) { - DCHECK(std::find(assigned_buffers_.begin(), assigned_buffers_.end(), - &buffer) == assigned_buffers_.end()) - << "LogicalBuffer " << buffer.ToString() - << " already assigned to allocation " << index(); - assigned_buffers_.push_back(&buffer); +using ::tensorflow::gtl::FlatMap; +using ::tensorflow::gtl::FlatSet; +using ::tensorflow::strings::Appendf; +using ::tensorflow::strings::HumanReadableNumBytes; + +size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { + uint64 h = std::hash()(s.index()); + h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); + h = tensorflow::Hash64Combine(h, std::hash()(s.size())); + return h; +} + +string BufferAllocation::Slice::ToString() const { + return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); +} + +BufferAllocation::Slice BufferAllocation::GetSlice( + const LogicalBuffer& buffer) const { + const OffsetSize os = FindOrDie(assigned_buffers_, &buffer); + return Slice(this, os.offset, os.size); +} + +void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, + int64 size) { + CHECK(assigned_buffers_.count(&buffer) == 0) + << "LogicalBuffer " << buffer << " already assigned to allocation " + << index_; + CHECK_LE(offset, size_) << "LogicalBuffer " << buffer + << " offset out of range"; + CHECK_LE(offset + size, size_) + << "LogicalBuffer " << buffer << " size out of range"; + CHECK_EQ(buffer.color(), color()) + << "Buffer color " << buffer.color() + << " does not match allocation color " << color() << "."; + OffsetSize offset_size; + offset_size.offset = offset; + offset_size.size = size; + assigned_buffers_.emplace(&buffer, offset_size); +} + +BufferAllocationProto BufferAllocation::ToProto() const { + BufferAllocationProto proto; + proto.set_index(index_); + proto.set_size(size_); + proto.set_is_thread_local(is_thread_local_); + proto.set_is_reusable(is_reusable_); + proto.set_color(color_.value()); + if (is_entry_computation_parameter_) { + proto.set_is_entry_computation_parameter(true); + proto.set_parameter_number(parameter_number_); + } + proto.set_maybe_live_out(maybe_live_out_); + for (const auto& buffer_offset_size : assigned_buffers_) { + BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned(); + proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id()); + proto_assigned->set_offset(buffer_offset_size.second.offset); + proto_assigned->set_size(buffer_offset_size.second.size); + } + return proto; } string BufferAllocation::ToString() const { @@ -52,19 +108,38 @@ string BufferAllocation::ToString() const { tensorflow::strings::StrAppend( &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld", index_, this, size())); + if (color().value() != 0) { + tensorflow::strings::StrAppend(&output, ", color ", color().value()); + } if (is_entry_computation_parameter()) { tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number()); } if (is_thread_local()) { tensorflow::strings::StrAppend(&output, ", thread-local"); } + if (maybe_live_out()) { + tensorflow::strings::StrAppend(&output, ", maybe-live-out"); + } + if (IsPreallocatedTempBuffer()) { + tensorflow::strings::StrAppend(&output, ", preallocated-temp"); + } tensorflow::strings::StrAppend(&output, ":\n"); - for (const auto& buffer : assigned_buffers()) { + // Dump the assigned buffers ordered by id. 
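
The `Slice::Hasher` above folds the three identifying fields together with `tensorflow::Hash64Combine`. A self-contained sketch of the same pattern follows; the constant-and-shift mixing function is a stand-in, not TensorFlow's actual `Hash64Combine`:

```c++
#include <cstddef>
#include <cstdint>
#include <functional>

// Boost-style hash mixing; illustrative stand-in for Hash64Combine.
inline uint64_t HashCombine(uint64_t seed, uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

struct SliceKey {
  int64_t index;
  int64_t offset;
  int64_t size;
};

struct SliceKeyHasher {
  size_t operator()(const SliceKey& s) const {
    uint64_t h = std::hash<int64_t>()(s.index);
    h = HashCombine(h, std::hash<int64_t>()(s.offset));
    h = HashCombine(h, std::hash<int64_t>()(s.size));
    return static_cast<size_t>(h);
  }
};
```

Any two slices that compare equal (same allocation index, offset, and size) hash identically, which is what lets `Slice` serve as a key in hash-based containers.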
+ std::vector sorted_buffers; + for (const auto& buffer_offset_size : assigned_buffers_) { + sorted_buffers.push_back(buffer_offset_size.first); + } + std::sort(sorted_buffers.begin(), sorted_buffers.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + for (const LogicalBuffer* buffer : sorted_buffers) { + const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); tensorflow::strings::StrAppend( &output, tensorflow::strings::Printf( - " %s::%s : %s\n", buffer->instruction()->parent()->name().c_str(), - buffer->ToString().c_str(), + " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), + offset_size.offset, offset_size.size, ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); } return output; @@ -75,6 +150,11 @@ std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) { return out; } +std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) { + out << s.ToString(); + return out; +} + const PointsToSet& BufferAssignment::GetPointsToSet( const HloInstruction* instruction) const { return points_to_analysis().GetPointsToSet(instruction); @@ -96,22 +176,21 @@ BufferAllocation* BufferAssignment::GetMutableAssignedAllocation( return const_cast(&GetAssignedAllocation(buffer)); } -std::set BufferAssignment::GetAllocations( +std::set BufferAssignment::GetAllSlices( const HloInstruction* instruction, const ShapeIndex& index) const { - std::set allocations; + std::set result; for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) { - if (allocation_index_for_buffer_.count(buffer) > 0) { - allocations.insert( - GetAllocation(allocation_index_for_buffer_.at(buffer))); + if (HasAllocation(*buffer)) { + result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer)); } } - return allocations; + return result; } const BufferAllocation& BufferAssignment::GetAllocation( BufferAllocation::Index index) const { - CHECK(index >= 0 && index < allocations_.size()) - << "Allocation index " << index << " is out of range."; + CHECK_GE(index, 0); + CHECK_LT(index, allocations_.size()); return allocations_[index]; } @@ -131,71 +210,212 @@ bool BufferAssignment::HasTopLevelAllocation( return false; } -StatusOr BufferAssignment::GetUniqueAllocation( +StatusOr BufferAssignment::GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const { - const BufferAllocation* allocation = nullptr; + BufferAllocation::Slice result; for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { if (HasAllocation(*buffer)) { - if (allocation != nullptr && - *allocation != GetAssignedAllocation(*buffer)) { + const BufferAllocation::Slice slice = + GetAssignedAllocation(*buffer).GetSlice(*buffer); + if (result.allocation() == nullptr) { + result = slice; + } else if (result != slice) { return FailedPrecondition( - "LogicalBuffer allocation for instruction %s at index {%s} cannot " + "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name().c_str(), index.ToString().c_str()); } - allocation = &GetAssignedAllocation(*buffer); } } - if (allocation == nullptr) { + if (result.allocation() == nullptr) { return FailedPrecondition( - "instruction %s has no buffer allocation at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + "BufferAllocation::Slice not assigned for instruction %s at index %s", + 
instruction->name().c_str(), index.ToString().c_str()); } - return allocation; + return result; } -StatusOr BufferAssignment::GetUniqueTopLevelAllocation( +StatusOr BufferAssignment::GetUniqueTopLevelSlice( const HloInstruction* instruction) const { - return GetUniqueAllocation(instruction, /*index=*/{}); + return GetUniqueSlice(instruction, /*index=*/{}); } -StatusOr -BufferAssignment::GetUniqueTopLevelOutputAllocation() const { - return GetUniqueTopLevelAllocation( +bool BufferAssignment::SharesSliceAtIndex( + const HloInstruction* hlo_a, const ShapeIndex& shape_index_a, + const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const { + return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() == + GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); +} + +StatusOr +BufferAssignment::GetUniqueTopLevelOutputSlice() const { + return GetUniqueTopLevelSlice( module_->entry_computation()->root_instruction()); } +BufferAllocation* BufferAssignment::NewEmptyAllocation( + int64 size, bool is_thread_local, bool is_reusable, + LogicalBuffer::Color color) { + BufferAllocation::Index index = allocations_.size(); + allocations_.emplace_back(index, size, is_thread_local, is_reusable, color); + BufferAllocation* allocation = &allocations_.back(); + return allocation; +} + BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, int64 size, bool is_thread_local, bool is_reusable) { - BufferAllocation::Index index = allocations_.size(); - allocations_.emplace_back(index, size, is_thread_local, is_reusable); - BufferAllocation* allocation = &allocations_.back(); - AddAssignment(buffer, allocation, /*colocated_buffer=*/false); - allocation_index_for_buffer_[&buffer] = index; + BufferAllocation* allocation = + NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); + AddAssignment(allocation, buffer, /*offset=*/0, size); return allocation; } // Adds an instruction to the set assigned to the given buffer. -void BufferAssignment::AddAssignment(const LogicalBuffer& buffer, - BufferAllocation* allocation, - bool colocated_buffer) { +void BufferAssignment::AddAssignment(BufferAllocation* allocation, + const LogicalBuffer& buffer, int64 offset, + int64 size) { CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; - CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty() || - colocated_buffer) + CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer"; TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - allocation->AddAssignment(buffer); + allocation->AddAssignment(buffer, offset, size); allocation_index_for_buffer_[&buffer] = allocation->index(); } +// Combines allocations of temporary buffers of the same color into one big +// BufferAllocation. +void BufferAssignment::CombineTempAllocations() { + FlatMap + combined_allocation_map; + + // Move all temp allocations into a single run at the end of the allocations + // vector. + const auto first_temp_it = + std::partition(allocations_.begin(), allocations_.end(), + [](const BufferAllocation& allocation) { + return !allocation.IsPreallocatedTempBuffer(); + }); + + // Walk over the run of temp allocations, collecting the allocations belonging + // to the same color. 
+ if (first_temp_it != allocations_.end()) { + for (auto it = first_temp_it; it != allocations_.end(); ++it) { + const BufferAllocation& temp_allocation = *it; + LogicalBuffer::Color color = temp_allocation.color(); + auto combined_it = combined_allocation_map.find(color); + if (combined_it == combined_allocation_map.end()) { + // We have found the first temp allocation of this color. Collect + // the other temp allocations of the same color into it. + combined_allocation_map.emplace(color, temp_allocation); + continue; + } + + auto* combined_allocation = &combined_it->second; + // Each temp allocation is placed end-to-end, accounting for alignment. + // The offset of each buffer in the combined allocation is computed from + // the base offset of the allocation. + const int64 base = + RoundUpToNearest(combined_allocation->size(), alignment_); + combined_allocation->set_size(base + temp_allocation.size()); + for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) { + const LogicalBuffer* buffer = buffer_offset_size.first; + const int64 offset = buffer_offset_size.second.offset; + const int64 size = buffer_offset_size.second.size; + combined_allocation->AddAssignment(*buffer, base + offset, size); + } + } + // Replace all existing temporary allocations with the new combined + // allocations. + allocations_.erase(first_temp_it, allocations_.end()); + for (auto& combined : combined_allocation_map) { + allocations_.push_back(combined.second); + temp_allocation_total_size_ += combined.second.size(); + } + } + + // Update allocation indices to their new positions. + allocation_index_for_buffer_.clear_no_resize(); + for (size_t index = 0; index < allocations_.size(); ++index) { + BufferAllocation* allocation = &allocations_[index]; + allocation->set_index(index); + for (const auto& buffer_offset_size : allocation->assigned_buffers_) { + const LogicalBuffer* buffer = buffer_offset_size.first; + allocation_index_for_buffer_[buffer] = index; + } + } +} + +Status BufferAssignment::ComputeSummaryStats() { + for (auto& allocation : Allocations()) { + if (allocation.is_entry_computation_parameter()) { + stats_.parameter_allocation_count++; + stats_.parameter_allocation_bytes += allocation.size(); + } + if (allocation.maybe_live_out()) { + stats_.maybe_live_out_allocation_count++; + stats_.maybe_live_out_allocation_bytes += allocation.size(); + } + if (allocation.IsPreallocatedTempBuffer()) { + stats_.preallocated_temp_allocation_count++; + stats_.preallocated_temp_allocation_bytes += allocation.size(); + } + stats_.total_allocation_count++; + stats_.total_allocation_bytes += allocation.size(); + } + + // Only compute total fragmentation if all computations are sequential. 
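
`CombineTempAllocations` above places each temp allocation of a color end-to-end, rounding every base offset up to the buffer alignment before appending. A stand-alone sketch of that placement arithmetic, with hypothetical helper names:

```c++
#include <cstdint>
#include <vector>

// Rounds `value` up to the nearest multiple of a positive `alignment`.
inline int64_t RoundUpToNearestMultiple(int64_t value, int64_t alignment) {
  return ((value + alignment - 1) / alignment) * alignment;
}

// Returns the base offset of each allocation inside the combined buffer; the
// final element is the combined buffer's total size.
std::vector<int64_t> PackEndToEnd(const std::vector<int64_t>& sizes,
                                  int64_t alignment) {
  std::vector<int64_t> bases;
  int64_t combined_size = 0;
  for (int64_t size : sizes) {
    const int64_t base = RoundUpToNearestMultiple(combined_size, alignment);
    bases.push_back(base);
    combined_size = base + size;
  }
  bases.push_back(combined_size);
  return bases;
}
```

For example, sizes {100, 64} with alignment 64 yield bases {0, 128} and a combined size of 192, since the second allocation's base is rounded from 100 up to 128.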
+ SequentialHloOrdering::HloModuleSequence module_sequence; + for (const auto& computation : module_->computations()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence != nullptr) { + module_sequence.emplace(computation.get(), *sequence); + } + } + if (module_sequence.size() == module_->computations().size()) { + TF_ASSIGN_OR_RETURN( + const int64 min_size, + MinimumMemoryForSequence(module_sequence, buffer_size_)); + stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; + } + + return Status::OK(); +} + +string BufferAssignment::Stats::ToString() const { + string s; + Appendf(&s, "BufferAssignment stats:\n"); + Appendf(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes).c_str()); + Appendf(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); + Appendf(&s, " preallocated temp allocation: %10s\n", + HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + if (preallocated_temp_fragmentation_bytes >= 0) { + const double percent = 100. * preallocated_temp_fragmentation_bytes / + preallocated_temp_allocation_bytes; + Appendf( + &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), + percent); + } + Appendf(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes).c_str()); + if (total_fragmentation_bytes >= 0) { + const double percent = + 100. * total_fragmentation_bytes / total_allocation_bytes; + Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + } + return s; +} + string BufferAssignment::ToString() const { string output; tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); @@ -205,6 +425,44 @@ string BufferAssignment::ToString() const { return output; } +BufferAssignmentProto BufferAssignment::ToProto() const { + BufferAssignmentProto proto; + // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssigment, + // because we need to do the HasAllocation check for each buffer. Otherwise + // the buffer_size_ call might fail for some backends. + const TuplePointsToAnalysis& points_to_analysis = + liveness_->points_to_analysis(); + for (const auto& buffer : points_to_analysis.logical_buffers()) { + if (HasAllocation(*buffer)) { + LogicalBufferProto proto_buffer = buffer->ToProto(buffer_size_); + proto.add_logical_buffers()->Swap(&proto_buffer); + + // Fill buffer aliases. 
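
The percentages printed by `Stats::ToString` above are plain ratios of fragmentation to allocated bytes, where `total_fragmentation_bytes` was computed earlier as `total_allocation_bytes` minus the `MinimumMemoryForSequence` lower bound. A worked example with hypothetical numbers:

```c++
#include <cstdint>
#include <cstdio>

int main() {
  // Suppose the heap-simulator lower bound says 80 MiB would suffice, but
  // 96 MiB were actually allocated.
  const int64_t total_allocation_bytes = 96LL << 20;
  const int64_t min_size = 80LL << 20;
  const int64_t total_fragmentation_bytes =
      total_allocation_bytes - min_size;  // 16 MiB
  const double percent =
      100. * total_fragmentation_bytes / total_allocation_bytes;
  std::printf("total fragmentation: %lld bytes (%.2f%%)\n",
              static_cast<long long>(total_fragmentation_bytes),
              percent);  // prints 16.67%
  return 0;
}
```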
+ for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + if (alias.instruction() == buffer->instruction() && + alias.index() == buffer->index()) { + continue; // skip self-aliases + } + BufferAssignmentProto::BufferAlias* proto_alias = + proto.add_buffer_aliases(); + LogicalBufferProto::Location proto_alias_location = + LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + proto_alias->set_source_buffer_id(buffer->id()); + proto_alias->mutable_location()->Swap(&proto_alias_location); + } + } + } + for (const BufferAllocation& allocation : Allocations()) { + BufferAllocationProto proto_allocation = allocation.ToProto(); + proto.add_buffer_allocations()->Swap(&proto_allocation); + } + for (const HeapSimulatorTrace& trace : heap_simulator_traces_) { + *proto.add_heap_simulator_traces() = trace; + } + return proto; +} + namespace { // Walk the call graph of the HLO module and place each computation into either @@ -213,7 +471,7 @@ namespace { // elements in thread_local_computations and global_computations are in post // order (if computation A has an instruction which calls computation B, then A // will appear after B in the vector). -tensorflow::Status GatherComputationsByAllocationType( +Status GatherComputationsByAllocationType( const HloModule* module, std::vector* thread_local_computations, std::vector* global_computations) { @@ -225,8 +483,8 @@ tensorflow::Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - std::unordered_set thread_local_set; - std::unordered_set global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -263,7 +521,8 @@ tensorflow::Status GatherComputationsByAllocationType( } for (auto& instruction : computation->instructions()) { - for (auto* subcomputation : instruction->MakeCalledComputationsSet()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kWhile: @@ -308,7 +567,7 @@ tensorflow::Status GatherComputationsByAllocationType( // will not appear in either thread_local_set or global_set. We don't bother // assigning buffers for these. } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace @@ -316,39 +575,33 @@ tensorflow::Status GatherComputationsByAllocationType( /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, bool colocate_related_buffers, - const std::vector* hlos_to_allocate) { - BufferAssigner assigner(std::move(buffer_size), colocate_related_buffers); + LogicalBuffer::SizeFunction buffer_size, int64 alignment, + bool allow_input_output_aliasing, TuplePointsToAnalysis::Colorer colorer) { + BufferAssigner assigner(alignment, allow_input_output_aliasing, + std::move(colorer)); return assigner.CreateAssignment(module, std::move(hlo_ordering), - hlos_to_allocate); -} - -/* static */ -StatusOr> BufferAssigner::Run( - const HloModule* module, std::unique_ptr hlo_ordering, - int64 pointer_size) { - return BufferAssigner::Run(module, std::move(hlo_ordering), - [pointer_size](const LogicalBuffer& buffer) { - return ShapeUtil::IsOpaque(buffer.shape()) - ? 
0 - : ShapeUtil::ByteSizeOf( - buffer.shape(), pointer_size); - }, - /*colocate_related_buffers=*/true); + std::move(buffer_size)); } bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, const LogicalBuffer& buffer, BufferAssignment* assignment) { + const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_; + CHECK(!assignment->HasAllocation(buffer)) << "buffer " << buffer << " already has an allocation assigned."; - VLOG(4) << "Trying to assign " << buffer.ToString() - << " to allocation: " << allocation->ToString(); + VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation; - if (buffer_size_(buffer) > allocation->size()) { + if (buffer.color() != allocation->color()) { + VLOG(4) << "Can't assign: buffer has color" << buffer.color() + << " and allocation has color " << allocation->color() << "."; + return false; + } + + if (buffer_size(buffer) > allocation->size()) { VLOG(4) << "Can't assign: buffer is larger than allocation (" - << buffer_size_(buffer) << " > " << allocation->size() << ")"; + << buffer_size(buffer) << " > " << allocation->size() << ")"; return false; } @@ -362,139 +615,198 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } - for (const LogicalBuffer* assigned_buffer : allocation->assigned_buffers()) { - if (assignment->liveness().MayInterfere(*assigned_buffer, buffer)) { - VLOG(4) << "Can't assign: assignee " << assigned_buffer->ToString() - << " may interfere with " << buffer.ToString(); + for (const auto& buffer_offset_size : allocation->assigned_buffers()) { + const LogicalBuffer& assigned_buffer = *buffer_offset_size.first; + if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " may interfere with " << buffer; return false; } + // Copy instruction don't share a buffer with their input operand. + if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) && + buffer.instruction()->opcode() == HloOpcode::kCopy) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " is used at copy instruction " << buffer; + return false; + } + } + + if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { + HloComputation* entry_computation = + assignment->module_->entry_computation(); + for (auto param : entry_computation->parameter_instructions()) { + for (auto& param_buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + param)) { + if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { + VLOG(4) << "Can't assign: Parameter interference with result"; + return false; + } + } + } } // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). 
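
Collecting the checks in `MaybeAssignBuffer`, including the live-out exact-fit rule whose code follows just below: an allocation is only reused when colors match, the buffer fits, the allocation is reusable (or still empty), no already-assigned buffer's live range interferes, and a maybe-live-out buffer fits the allocation exactly. A distilled version of that predicate over simplified stand-in types (the kCopy-operand rule is folded into `may_interfere` here for brevity):

```c++
#include <cstdint>

struct BufferInfo {
  int color;
  int64_t size;
  bool maybe_live_out;
};

struct AllocationInfo {
  int color;
  int64_t size;
  bool is_reusable;
  bool has_assignments;
};

// Returns true iff `buffer` may share `alloc` under the rules above.
bool CanReuse(const AllocationInfo& alloc, const BufferInfo& buffer,
              bool may_interfere) {
  if (buffer.color != alloc.color) return false;
  if (buffer.size > alloc.size) return false;
  if (!alloc.is_reusable && alloc.has_assignments) return false;
  if (may_interfere) return false;
  // Live-out buffers can have arbitrary lifetimes, so they may only share an
  // allocation that fits them exactly.
  if (buffer.maybe_live_out && alloc.size != buffer.size) return false;
  return true;
}
```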
if (assignment->liveness().MaybeLiveOut(buffer) && - allocation->size() != buffer_size_(buffer)) { - VLOG(4) << "Can't assign: buffer " << buffer.ToString() + allocation->size() != buffer_size(buffer)) { + VLOG(4) << "Can't assign: buffer " << buffer << "is live out and size not the same as allocation"; return false; } - assignment->AddAssignment(buffer, allocation, /*colocated_buffer=*/false); + assignment->AddAssignment(allocation, buffer, /*offset=*/0, + buffer_size(buffer)); return true; } -tensorflow::Status BufferAssigner::AssignBuffersForComputation( +Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet* hlos_to_allocate, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, + const FlatSet& colocated_buffers, + const FlatSet& colocated_allocations, + FlatMap>* + buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. std::vector sorted_buffers; for (auto& instruction : computation->instructions()) { - if (hlos_to_allocate == nullptr || - hlos_to_allocate->count(instruction.get()) > 0) { - // Add all buffers which this instruction defines. Instruction which don't - // define buffers (eg, bitcast which just forwards a pointer) don't need - // any allocations. - for (const LogicalBuffer* buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction.get())) { - sorted_buffers.push_back(buffer); - } + // Add all buffers which this instruction defines. Instruction which don't + // define buffers (eg, bitcast which just forwards a pointer) don't need + // any allocations. + for (const LogicalBuffer* buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + instruction.get())) { + sorted_buffers.push_back(buffer); } } // Generate a post order sort of instructions for sorting of the // LogicalBuffers. - tensorflow::gtl::FlatMap post_order_position; + FlatMap post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); position++; } + // If there is a sequential instruction ordering, we'll delay assignment of + // temp buffers until after the main assignment loop. + const BufferLiveness& liveness = assignment->liveness(); + const bool has_sequential_order = + liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; + if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { + // Every sequential computation must get an entry in the + // buffers_to_assign_sequentially map, even if we end up with an empty set + // of buffers. This ensures we can correctly determine whether to run + // whole-module heap simulation. + buffers_to_assign_sequentially->emplace(computation, + FlatSet()); + } + // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers // first for simplicity. This means any previously created BufferAllocation is // necessarily large enough to hold the output of the current Buffer in // consideration. // - // As a secondary sorting criteria, use post order position of the HLO - // instruction which defines the buffer. 
This means an instruction will appear - // after its operands (assuming operands are the same/larger size) enabling - // the important reuse case where an elementwise instruction reuses one of its + // As a secondary sorting criteria, if the instructions are sequentially + // ordered, we assign live-out buffers before others. Note that for sequential + // computations, we'll take temp buffers that can't re-use any allocations and + // assign them via a heap scheduler. By assigning live-out buffers first, we + // increase the odds that temp buffers can re-use an allocation. + // + // As a final tiebreaker use post order position of the HLO instruction which + // defines the buffer. This means an instruction will appear after its + // operands (assuming operands are the same/larger size) enabling the + // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, &post_order_position](const LogicalBuffer* a, - const LogicalBuffer* b) { - int64 a_size = buffer_size_(*a); - int64 b_size = buffer_size_(*b); - if (a_size == b_size) { - // For instructions with the same size buffers, sort them in - // post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - } else { - // We want the HLOs sorted in reverse order by size so use ">". - return a_size > b_size; + [this, has_sequential_order, &liveness, &post_order_position, + assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->buffer_size_(*a); + const int64 b_size = assignment->buffer_size_(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. } + // Otherwise live out buffers come before others, if the + // instructions are sequentially ordered. + if (has_sequential_order) { + const bool a_live_out = liveness.MaybeLiveOut(*a); + const bool b_live_out = liveness.MaybeLiveOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + } + // Final tiebreaker is in instruction post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); }); // BufferAllocations are necessarily created in decreasing size order. Keep // indices of previously created BufferAllocations in allocation_indices. std::vector allocation_indices; - for (const auto* buffer : sorted_buffers) { - VLOG(3) << "Assigning allocation to: " << buffer->ToString(); + for (const LogicalBuffer* buffer : sorted_buffers) { + VLOG(3) << "Assigning allocation to: " << *buffer; if (colocated_buffers.count(buffer) > 0) { // Colocated buffers are currently assigned in an earlier pass. + VLOG(3) << "Skipping colocated buffer: " << *buffer; continue; } TF_RET_CHECK(!assignment->HasAllocation(*buffer)); - if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + const HloInstruction* instruction = buffer->instruction(); + if (instruction->opcode() == HloOpcode::kConstant) { // No BufferAllocations for constants. // TODO(b/32248867): For consistency, constants should get allocations. 
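
The three-level ordering used in the `std::sort` above, extracted into a stand-alone comparator over simplified keys (the struct and function names are illustrative): decreasing size first, then live-out buffers when the computation has a sequential order, then instruction post order as the tiebreaker.

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

struct BufferSortKey {
  int64_t size;
  bool live_out;
  int post_order_position;
};

void SortForAssignment(std::vector<BufferSortKey>* buffers,
                       bool has_sequential_order) {
  std::sort(buffers->begin(), buffers->end(),
            [has_sequential_order](const BufferSortKey& a,
                                   const BufferSortKey& b) {
              // Primary: decreasing size.
              if (a.size != b.size) return a.size > b.size;
              // Secondary: live-out buffers first, if sequentially ordered.
              if (has_sequential_order && a.live_out != b.live_out) {
                return a.live_out;
              }
              // Tiebreaker: instruction post order.
              return a.post_order_position < b.post_order_position;
            });
}
```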
+ VLOG(3) << "Skipping constant: " << *buffer; continue; } - if (buffer->instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation()) { + const int64 buffer_size = assignment->buffer_size_(*buffer); + + const bool is_entry_parameter = + instruction->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); + if (is_entry_parameter) { // If the LogicalBuffer is part of an external parameter, creates a new // allocation and sets its parameter number. Parameters of non-entry // computations do not need special allocations because they live inside // callers. BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size_(*buffer), + assignment->NewAllocation(*buffer, buffer_size, /*is_thread_local=*/false, /*is_reusable=*/false); allocation->set_entry_computation_parameter( - buffer->instruction()->parameter_number()); - VLOG(3) << "New allocation for entry computation parameter: " - << buffer->ToString(); + instruction->parameter_number()); + VLOG(3) << "New allocation #" << allocation->index() + << " for entry computation parameter: " << *buffer; continue; } legacy_flags::BufferAssignmentFlags* flags = legacy_flags::GetBufferAssignmentFlags(); if (!flags->xla_enable_buffer_reuse || is_thread_local || - buffer->instruction()->opcode() == HloOpcode::kCustomCall) { + instruction->opcode() == HloOpcode::kCustomCall) { // Custom call operations never have reusable buffers. Also we do not // reuse thread-local buffers for now, because they are dynamically // allocated and their lifetimes are hard to compute. - assignment->NewAllocation(*buffer, buffer_size_(*buffer), is_thread_local, - /*is_reusable=*/false); + BufferAllocation* allocation = assignment->NewAllocation( + *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + VLOG(3) << "New allocation #" << allocation->index() + << " for thread-local/CustomCall: " << *buffer; continue; } if (ShapeUtil::IsTuple(buffer->shape())) { // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // assumes longer buffer liveness than indicated by the analysis. - assignment->NewAllocation(*buffer, buffer_size_(*buffer), is_thread_local, - /*is_reusable=*/false); + BufferAllocation* allocation = assignment->NewAllocation( + *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + VLOG(3) << "New allocation #" << allocation->index() + << " for tuple-shaped buffer: " << *buffer; continue; } @@ -503,23 +815,23 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation( // (checked in liveness analysis) which are necessarily top-level // array-shaped buffers. if (buffer->IsTopLevel() && !buffer->IsTuple()) { - for (auto* operand : buffer->instruction()->operands()) { + for (auto* operand : instruction->operands()) { bool assigned_operand = false; - for (const auto& operand_allocation : - assignment->GetAllocations(operand, /*index=*/{})) { + for (const auto& operand_slice : + assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = - assignment->GetMutableAllocation(operand_allocation.index()); + assignment->GetMutableAllocation(operand_slice.index()); if (colocated_allocations.count(allocation->index()) == 0) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). 
However, // the call to MaybeAssignBuffer is safe as it returns false if // allocation.size < buffer.size. - CHECK_GE(allocation->size(), buffer_size_(*buffer)); + CHECK_GE(allocation->size(), buffer_size); } if (MaybeAssignBuffer(allocation, *buffer, assignment)) { - VLOG(3) << "Reusing (operand) allocation for: " - << buffer->ToString(); + VLOG(3) << "Reusing (operand) allocation #" << allocation->index() + << " for: " << *buffer; assigned_operand = true; break; } @@ -546,24 +858,148 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation( // invariant in this function (causing this CHECK to fail). However, // the call to MaybeAssignBuffer is safe as it returns false if // allocation.size < buffer.size. - CHECK_GE(allocation->size(), buffer_size_(*buffer)); + CHECK_GE(allocation->size(), buffer_size); } if (MaybeAssignBuffer(allocation, *buffer, assignment)) { - VLOG(3) << "Reusing buffer for: " << buffer->ToString(); + VLOG(3) << "Reusing allocation #" << allocation->index() + << " for: " << *buffer; break; } } } + + if (!assignment->HasAllocation(*buffer) && has_sequential_order && + !liveness.MaybeLiveOut(*buffer)) { + // There is a sequential instruction ordering, so we delay assignment of + // temp buffers until after the loop. We do this right before we decide to + // create a new allocation, to ensure we've exhausted all the buffer + // re-use cases above. + // + // Entry parameters and thread local buffers were already handled earlier + // in this loop iteration. See BufferAllocation::IsPreallocatedTempBuffer + // for the definition of temp buffers. + CHECK(!is_entry_parameter) << *buffer; + CHECK(!is_thread_local) << *buffer; + (*buffers_to_assign_sequentially)[computation].insert(buffer); + VLOG(3) << "Delaying assignment of temp buffer: " << *buffer; + continue; + } + if (!assignment->HasAllocation(*buffer)) { - auto* allocation = - assignment->NewAllocation(*buffer, buffer_size_(*buffer), - is_thread_local, /*is_reusable=*/true); - VLOG(3) << "New allocation for: " << buffer->ToString(); + BufferAllocation* allocation = assignment->NewAllocation( + *buffer, buffer_size, is_thread_local, /*is_reusable=*/true); allocation_indices.push_back(allocation->index()); + VLOG(3) << "New allocation #" << allocation->index() + << " for: " << *buffer; } } - return tensorflow::Status::OK(); + + return Status::OK(); +} + +FlatMap, + LogicalBuffer::Color::Hasher> +BufferAssigner::SplitBuffersByColor( + const FlatSet& buffers) { + FlatMap, + LogicalBuffer::Color::Hasher> + color_map; + for (auto buffer : buffers) { + color_map[buffer->color()].insert(buffer); + } + return color_map; +} + +Status BufferAssigner::AssignBuffersWithSequentialOrdering( + const FlatMap>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment) { + // Run the sequence of instructions through the heap simulator. The heuristic + // that seems to give the best results is lazy-best-fit, with all runs of + // alloc / free calls sorted in decreasing size order. + const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + if (run_whole_module_heap_simulation) { + // Run the heap simulation over the whole module. This reduces memory usage, + // since buffers for kCall and kWhile sub-computations are only live for the + // duration of their calling instructions. 
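
`SplitBuffersByColor` above is a straightforward group-by; the heap simulator then runs once per color class, so differently colored buffers never end up sharing an allocation. A minimal equivalent over plain ids (illustrative types; XLA keys `LogicalBuffer*` by `LogicalBuffer::Color` with its own hasher):

```c++
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Buffers are (id, color) pairs here, purely for illustration.
std::unordered_map<int, std::unordered_set<int64_t>> SplitByColor(
    const std::vector<std::pair<int64_t, int>>& buffers) {
  std::unordered_map<int, std::unordered_set<int64_t>> color_map;
  for (const auto& buffer : buffers) {
    color_map[buffer.second].insert(buffer.first);
  }
  return color_map;
}
```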
+ VLOG(1) << "Running whole-module heap simulation"; + SequentialHloOrdering::HloModuleSequence module_sequence; + FlatSet all_buffers_to_assign; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + module_sequence[computation] = *instruction_sequence; + all_buffers_to_assign.insert(buffers_to_assign.begin(), + buffers_to_assign.end()); + } + auto color_map = SplitBuffersByColor(all_buffers_to_assign); + for (auto& single_colored_set : color_map) { + VLOG(2) << "Simulating heap for color " << single_colored_set.first; + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + assignment->module(), module_sequence, + assignment->points_to_analysis(), + assignment->buffer_size_, + &single_colored_set.second)); + AssignBuffersFromHeapSimulator(result, assignment, + single_colored_set.first); + } + } else { + // Run the heap-simulation on a per-computation basis. Buffers for + // sub-computations are assigned disjoint BufferAllocations, assuming the + // worst-case that they may all be live concurrently. + VLOG(1) << "Running per-computation heap simulation"; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + auto color_map = SplitBuffersByColor(buffers_to_assign); + for (auto& single_colored_set : color_map) { + VLOG(2) << "Simulating heap for color " << single_colored_set.first; + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), + assignment->buffer_size_, + &single_colored_set.second)); + AssignBuffersFromHeapSimulator(result, assignment, + single_colored_set.first); + } + } + } + return Status::OK(); +} + +void BufferAssigner::AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, BufferAssignment* assignment, + LogicalBuffer::Color color) { + if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { + assignment->stats_.preallocated_temp_fragmentation_bytes = + result.fragmentation_size; + } else { + assignment->stats_.preallocated_temp_fragmentation_bytes += + result.fragmentation_size; + } + + BufferAllocation* allocation = assignment->NewEmptyAllocation( + result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); + for (const auto& buffer_chunk : result.chunk_map) { + const LogicalBuffer& buffer = *buffer_chunk.first; + const HeapSimulator::Chunk& chunk = buffer_chunk.second; + assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); + } + + assignment->heap_simulator_traces_.push_back(result.debug_trace); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -586,12 +1022,14 @@ void BufferAssigner::AddSetToColocatedBufferSets( } // Find existing sets that overlap with at least one buffer from the - // colocated_set. + // colocated_set. The resulting 'overlap_set_indices' will have at most + // colocated_buffer_sets->size() entries, and will be in increasing order. 
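
The loop that follows performs the overlap scan; putting the scan and the subsequent merge together, the whole of `AddSetToColocatedBufferSets` reduces to the sketch below, with integer buffer ids standing in for `LogicalBuffer*`. Because the overlap indices are collected in increasing order, erasing merged sets from the highest index down keeps the remaining indices (and the reference to the first set) valid.

```c++
#include <cstdint>
#include <set>
#include <vector>

using ColocatedSet = std::set<int64_t>;  // buffer ids, for illustration

void AddToColocatedSets(const ColocatedSet& new_set,
                        std::vector<ColocatedSet>* sets) {
  // Record every existing set sharing at least one buffer with new_set.
  std::vector<size_t> overlap_indices;
  for (size_t i = 0; i < sets->size(); ++i) {
    for (int64_t buffer : new_set) {
      if ((*sets)[i].count(buffer) > 0) {
        overlap_indices.push_back(i);
        break;
      }
    }
  }
  if (overlap_indices.empty()) {
    sets->push_back(new_set);
    return;
  }
  // Merge new_set and all later overlapping sets into the first one.
  ColocatedSet& merged = (*sets)[overlap_indices[0]];
  merged.insert(new_set.begin(), new_set.end());
  for (size_t j = overlap_indices.size() - 1; j >= 1; --j) {
    const ColocatedSet& other = (*sets)[overlap_indices[j]];
    merged.insert(other.begin(), other.end());
    sets->erase(sets->begin() + overlap_indices[j]);
  }
}
```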
std::vector overlap_set_indices; - for (const LogicalBuffer* buffer : colocated_set) { - for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { + for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { + for (const LogicalBuffer* buffer : colocated_set) { if ((*colocated_buffer_sets)[index].count(buffer) > 0) { overlap_set_indices.push_back(index); + break; } } } @@ -622,40 +1060,154 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } +// Conceptually the same as AddSetToColocatedBufferSets, but specific to the +// colocated buffers for while instructions. 'colocated_set' contains the +// buffers for a single while instruction that must be colocated. The idea here +// is to apply a memory-saving heuristic for separate while instructions whose +// buffers are disjoint in liveness, by using the colocation mechanism to force +// buffer sharing. This often reduces memory for multi-layer RNNs. +// +// TODO(b/32491382): We should be able to remove this heuristic after we +// implement module-level liveness analysis, which would let us directly detect +// buffer sharing opportunities between the while instruction buffer and the +// buffers from the predicate and body computation, as well as sharing across +// different while instructions. +void BufferAssigner::AddWhileSetToColocatedBufferSets( + const std::vector& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size, + std::vector* colocated_buffer_sets) { + CHECK(!colocated_set.empty()); + const TuplePointsToAnalysis& points_to_analysis = + buffer_liveness.points_to_analysis(); + + // Parallel while loops cannot safely share colocated buffer sets. + if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) { + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + return; + } + + // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets + // are added in postorder over computations and instructions. + const int64 init_buffer_size = buffer_size(*while_init_buffer); + for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { + const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; + + // Skip predecessor sets not associated with while loops. + if (std::all_of(predecessor_set.begin(), predecessor_set.end(), + [](const LogicalBuffer* buffer) { + return buffer->instruction()->opcode() != + HloOpcode::kWhile; + })) { + continue; + } + + // Skip predecessor sets already associated with 'while_hlo'. + if (std::any_of(predecessor_set.begin(), predecessor_set.end(), + [&while_hlo](const LogicalBuffer* buffer) { + return buffer->instruction() == while_hlo; + })) { + continue; + } + + // Build vector of predecessor while result and init buffers, which are + // checked for liveness interference below. We must check both the result + // and init buffers because they're aliased together, but + // TuplePointsToAnalysis is unaware of this aliasing. 
+ std::vector predecessor_while_buffers; + for (const LogicalBuffer* buffer : predecessor_set) { + const HloInstruction* instruction = buffer->instruction(); + if (instruction->opcode() == HloOpcode::kWhile && + buffer_size(*buffer) == init_buffer_size && + instruction->parent() == &computation) { + predecessor_while_buffers.push_back(buffer); + // Add the init buffer at the same index, which must also exist in the + // predecessor set, and must be unambiguous. + const PointsToSet& init_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + const std::vector& init_buffers = + init_points_to.element(buffer->index()); + CHECK_EQ(init_buffers.size(), 1); + CHECK_GT(predecessor_set.count(init_buffers[0]), 0); + predecessor_while_buffers.push_back(init_buffers[0]); + } + } + if (predecessor_while_buffers.empty()) { + continue; + } + + // Skip predecessor set if the live range of any predecessor buffers + // overlaps with 'while_init_buffer'. Note that tuple element buffer + // forwarding can cause the same buffer to appear on both sides of the + // interference comparison below. + if (std::any_of( + predecessor_while_buffers.begin(), predecessor_while_buffers.end(), + [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { + return while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer); + })) { + continue; + } + + // All our checks have passed; merge 'predecessor_set' with 'colocated_set', + // and add the merged set to 'colocated_buffer_sets'. This forces the + // colocation of buffers across different while instructions. + FlatSet unique; + unique.insert(predecessor_set.begin(), predecessor_set.end()); + unique.insert(colocated_set.begin(), colocated_set.end()); + std::vector merged_set(unique.begin(), unique.end()); + AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets); + return; + } + + // Failed to merge into predecessor set; add 'colocated_set' as-is. + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); +} + namespace { + // Checks that points-to set of 'instruction' is unambiguous and distinct // (ensured by CopyInsertion), then adds the buffer from the points-to set at // 'index' to 'colocated_set'. -void AddBufferToColocatedSet(const HloInstruction* instruction, - const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); CHECK(!points_to.IsAmbiguous()); CHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); } + } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile and kCall). 
void BufferAssigner::BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { - for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + const TuplePointsToAnalysis& points_to_analysis = + buffer_liveness.points_to_analysis(); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + for (const HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { const HloOpcode opcode = instruction->opcode(); if (opcode == HloOpcode::kWhile) { - HloInstruction* while_hlo = instruction.get(); - TF_CHECK_OK(ShapeUtil::ForEachSubshape( + const HloInstruction* while_hlo = instruction; + ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, colocated_buffer_sets]( + [this, while_hlo, &points_to_analysis, &buffer_liveness, + buffer_size, computation, colocated_buffer_sets]( const Shape& /*subshape*/, const ShapeIndex& index) { std::vector colocated_set; // Add while.init. - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); + auto* init_buffer = + AddBufferToColocatedSet(while_hlo->operand(0), index, + points_to_analysis, &colocated_set); // Add while.result. AddBufferToColocatedSet(while_hlo, index, points_to_analysis, &colocated_set); @@ -671,13 +1223,15 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - return tensorflow::Status::OK(); - })); + AddWhileSetToColocatedBufferSets( + colocated_set, init_buffer, while_hlo, *computation, + buffer_liveness, buffer_size, colocated_buffer_sets); + }); } else if (opcode == HloOpcode::kCall) { - HloInstruction* call_hlo = instruction.get(); - HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction(); - TF_CHECK_OK(ShapeUtil::ForEachSubshape( + const HloInstruction* call_hlo = instruction; + const HloInstruction* root_hlo = + call_hlo->to_apply()->root_instruction(); + ShapeUtil::ForEachSubshape( call_hlo->shape(), [this, call_hlo, root_hlo, &points_to_analysis, colocated_buffer_sets](const Shape& /*subshape*/, @@ -690,8 +1244,7 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet(root_hlo, index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - return tensorflow::Status::OK(); - })); + }); } } } @@ -702,23 +1255,43 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations) { + FlatSet* colocated_buffers, + FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; + // Set 'entry_parameter_number' if entry param in 'colocated_buffer_set'. 
+ int64 entry_parameter_number = -1; + for (const LogicalBuffer* buffer : colocated_buffer_set) { + const HloInstruction* instruction = buffer->instruction(); + const HloComputation* computation = instruction->parent(); + if (instruction->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation()) { + entry_parameter_number = instruction->parameter_number(); + break; + } + } + for (const LogicalBuffer* buffer : colocated_buffer_set) { if (allocation == nullptr) { // TODO(b/32491382) Avoid current trivial solution of using new // allocations for each colocated buffer set. When liveness has // module-level scope, we can allow buffers to be shared across // computations (in some cases). - allocation = assignment->NewAllocation(*buffer, buffer_size_(*buffer), - /*is_thread_local=*/false, - /*is_reusable=*/true); + allocation = assignment->NewAllocation( + *buffer, assignment->buffer_size_(*buffer), + /*is_thread_local=*/false, /*is_reusable=*/true); + if (entry_parameter_number >= 0) { + // This colocated buffer set contains an entry parameter and other + // logical buffers which use the parameter as read-only in a while + // body computation (which updates in place). + // Set 'entry_computation_parameter' to indicate that it contains + // an entry parameter, and to prevent reuse in MaybeAssignBuffer. + allocation->set_entry_computation_parameter(entry_parameter_number); + } colocated_allocations->insert(allocation->index()); } else { - assignment->AddAssignment(*buffer, allocation, - /*colocated_buffer=*/true); + assignment->AddAssignment(allocation, *buffer, /*offset=*/0, + assignment->buffer_size_(*buffer)); } colocated_buffers->insert(buffer); } @@ -727,121 +1300,88 @@ void BufferAssigner::AssignColocatedBufferSets( StatusOr> BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - const std::vector* hlos_to_allocate) { + LogicalBuffer::SizeFunction buffer_size) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, - BufferLiveness::Run(module, std::move(hlo_ordering))); + BufferLiveness::Run(module, std::move(hlo_ordering), + std::move(colorer_))); - std::vector thread_local_computations; - std::vector global_computations; VLOG(1) << "Assigning buffers to module " << module->name(); - if (hlos_to_allocate != nullptr) { - VLOG(3) << "LogicalBuffer assignment restricted to hlos: "; - for (auto hlo : *hlos_to_allocate) { - VLOG(3) << " " << hlo->parent()->name() << "::" << hlo->name(); - } - } - XLA_VLOG_LINES(3, module->ToString()); + XLA_VLOG_LINES(2, module->ToString()); XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( - module, &thread_local_computations, &global_computations)); - - // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to - // AssignBuffersForComputation for fast membership testing. - std::unique_ptr> hlo_set; - if (hlos_to_allocate != nullptr) { - hlo_set = MakeUnique>( - hlos_to_allocate->begin(), hlos_to_allocate->end()); - } - // Can't use MakeUnique because BufferAssignment constructor is private. - std::unique_ptr assignment( - new BufferAssignment(module, std::move(liveness))); + std::unique_ptr assignment(new BufferAssignment( + module, std::move(liveness), alignment_, std::move(buffer_size))); // Assign buffers with the tightest constraints first (colocated buffer sets). 
  // Once b/32491382 enables module-level liveness analysis, we may be able
  // to assign colocated buffers (or at least reuse their allocation for
  // buffers outside of the set) in AssignBuffersForComputation.
-  tensorflow::gtl::FlatSet<const LogicalBuffer*> colocated_buffers;
-  tensorflow::gtl::FlatSet<BufferAllocation::Index> colocated_allocations;
-  if (colocate_related_buffers_) {
-    std::vector<ColocatedBufferSet> colocated_buffer_sets;
-    BuildColocatedBufferSets(module, assignment->points_to_analysis(),
-                             &colocated_buffer_sets);
-    AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
-                              &colocated_buffers, &colocated_allocations);
-  }
+  FlatSet<const LogicalBuffer*> colocated_buffers;
+  FlatSet<BufferAllocation::Index> colocated_allocations;
+  std::vector<ColocatedBufferSet> colocated_buffer_sets;
+  BuildColocatedBufferSets(module, assignment->liveness(),
+                           assignment->buffer_size_, &colocated_buffer_sets);
+  AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
+                            &colocated_buffers, &colocated_allocations);
 
+  std::vector<const HloComputation*> thread_local_computations;
+  std::vector<const HloComputation*> global_computations;
+  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
+      module, &thread_local_computations, &global_computations));
+
+  // First assign buffers for global computations. Temporary buffers for
+  // sequential computations are collected in 'buffers_to_assign_sequentially'.
+  FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
+      buffers_to_assign_sequentially;
   for (auto* computation : global_computations) {
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
-        computation, /*is_thread_local=*/false, hlo_set.get(),
-        colocated_buffers, colocated_allocations, assignment.get()));
+        computation, /*is_thread_local=*/false, colocated_buffers,
+        colocated_allocations, &buffers_to_assign_sequentially,
+        assignment.get()));
   }
 
+  // Assign buffers with sequential ordering, if any. If all global
+  // computations are sequential, we can run heap simulation on the whole
+  // module, which reduces memory usage.
+  const bool run_whole_module_heap_simulation =
+      buffers_to_assign_sequentially.size() == global_computations.size();
+  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
+      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
+      assignment.get()));
+
+  // Now assign buffers for thread-local computations. All LogicalBuffers get
+  // their own BufferAllocation.
   for (auto* computation : thread_local_computations) {
     TF_RET_CHECK(computation != module->entry_computation());
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
-        computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers,
-        colocated_allocations, assignment.get()));
+        computation, /*is_thread_local=*/true, colocated_buffers,
+        colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr,
+        assignment.get()));
   }
 
   // Mark all buffers which may be live out of the entry computation as
   // "liveout".
- auto entry = module->entry_computation(); - auto root_instruction = entry->root_instruction(); - const PointsToSet& root_points_to = - assignment->GetPointsToSet(root_instruction); - TF_RETURN_IF_ERROR(root_points_to.ForEachElement( - [&assignment](const ShapeIndex& /*index*/, bool /*is_leaf*/, - const std::vector& buffers) { - for (const LogicalBuffer* buffer : buffers) { - VLOG(3) << "maybe_live_out LogicalBuffer: " << buffer->ToString(); - if (assignment->HasAllocation(*buffer)) { - BufferAllocation* alloc = - assignment->GetMutableAssignedAllocation(*buffer); - alloc->set_maybe_live_out(true); - VLOG(3) << "maybe_live_out BufferAllocation: " << alloc->ToString(); - } - } - return tensorflow::Status::OK(); - })); - - XLA_VLOG_LINES(2, assignment->ToString()); - - // Compute sizes of various kinds of buffers for logging. - int64 total_size = 0; - int64 parameter_size = 0; - for (auto& allocation : assignment->Allocations()) { - if (allocation.is_entry_computation_parameter()) { - parameter_size += allocation.size(); + for (const LogicalBuffer* buffer : + assignment->liveness().maybe_live_out_buffers()) { + VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer; + if (assignment->HasAllocation(*buffer)) { + BufferAllocation* alloc = + assignment->GetMutableAssignedAllocation(*buffer); + alloc->set_maybe_live_out(true); + VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc; } - total_size += allocation.size(); } - // Compute the total size of the output. Iterate over the subshapes and sum up - // the sizes of the buffers for each subshape. - int64 output_size = 0; - HloInstruction* root = module->entry_computation()->root_instruction(); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( - root->shape(), [this, &output_size, root, &assignment]( - const Shape& /*subshape*/, const ShapeIndex& index) { - const auto& allocations = assignment->GetAllocations(root, index); - if (!allocations.empty()) { - output_size += allocations.begin()->size(); - } - return tensorflow::Status::OK(); - })); + // Combines allocations of temporary buffers into one big BufferAllocation. + // This can only be performed after all buffers have been assigned, and after + // maybe_live_out is marked, since it is used to determine whether an + // allocation contains temporary buffers or not. + assignment->CombineTempAllocations(); - VLOG(1) << "Allocation sizes for module " << module->name() << ":"; - VLOG(1) << " parameter allocation total size: " - << tensorflow::strings::HumanReadableNumBytes(parameter_size); - VLOG(1) << " output allocation total size: " - << tensorflow::strings::HumanReadableNumBytes(output_size); - VLOG(1) << " temp allocation total size: " - << tensorflow::strings::HumanReadableNumBytes( - total_size - parameter_size - output_size); - VLOG(1) << " total allocation size: " - << tensorflow::strings::HumanReadableNumBytes(total_size); + XLA_VLOG_LINES(2, assignment->ToString()); + TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats()); + XLA_VLOG_LINES(1, assignment->GetStats().ToString()); return std::move(assignment); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index e7aeb35967e..b3933f11c1e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -20,19 +20,21 @@ limitations under the License. 
#include #include #include -#include -#include #include #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -41,12 +43,15 @@ limitations under the License. namespace xla { // This class abstracts an allocation of contiguous memory which can hold the -// values described by LogicalBuffers. A BufferAllocation may hold different -// LogicalBuffers at different times, but currently never more than one -// LogicalBuffer simultaneously. The abstraction includes information required -// by the backends for allocation, use, and deallocation of the buffer. This -// includes the LogicalBuffers which are held in this allocation through the -// execution of the computation. +// values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range +// of the allocation, represented by a Slice. A single BufferAllocation may hold +// LogicalBuffers with disjoint liveness, which may have overlapping Slices. A +// single BufferAllocation may also hold LogicalBuffers with overlapping +// liveness, which must have disjoint Slices. +// +// The abstraction includes information required by the backends for allocation, +// use, and deallocation of the buffer. This includes the LogicalBuffers which +// are held in this allocation through the execution of the computation. class BufferAllocation { public: // Holds a unique identifier for each allocation. Values are assigned @@ -54,15 +59,16 @@ class BufferAllocation { using Index = int64; BufferAllocation(Index index, int64 size, bool is_thread_local, - bool is_reusable) + bool is_reusable, LogicalBuffer::Color color) : index_(index), size_(size), is_thread_local_(is_thread_local), - is_reusable_(is_reusable) {} + is_reusable_(is_reusable), + color_(color) {} ~BufferAllocation() {} - // Adds a LogicalBuffer to the set assigned to this buffer. - void AddAssignment(const LogicalBuffer& buffer); + // Returns the index of this allocation. + Index index() const { return index_; } // Whether this allocation is used in a parallel calling context such as // inside of a map or reduce computation. Such allocations need to be thread @@ -84,30 +90,83 @@ class BufferAllocation { CHECK(is_entry_computation_parameter_); return parameter_number_; } - // Sets that this allocation holds a LogicalBuffer from a parameter of the - // entry computation. - void set_entry_computation_parameter(int64 parameter_number) { - is_entry_computation_parameter_ = true; - parameter_number_ = parameter_number; - } - // Returns/sets whether this allocation is assigned a LogicalBuffer which may + // Returns whether this allocation is assigned a LogicalBuffer which may // be live out of the entry computation. 
bool maybe_live_out() const { return maybe_live_out_; } - void set_maybe_live_out(bool value) { maybe_live_out_ = value; } // Returns the size of the allocation. Necessarily this must be at least as // large as any LogicalBuffer assigned to this allocation. int64 size() const { return size_; } - // Access to the logical buffers assigned to this allocation. - const std::vector& assigned_buffers() const { + // Returns the color of the allocation. Only logical buffers with a matching + // color can reside in this allocation. + LogicalBuffer::Color color() const { return color_; } + + struct OffsetSize { + int64 offset = 0; + int64 size = 0; + }; + + // Access to the logical buffers assigned to this allocation, and their + // associated logical offsets and sizes. + const tensorflow::gtl::FlatMap& + assigned_buffers() const { return assigned_buffers_; } - Index index() const { return index_; } + // A Slice represents a contiguous portion of a memory allocation. It is used + // to identify the memory range that a LogicalBuffer corresponds to. + class Slice { + public: + Slice() {} + Slice(const BufferAllocation* allocation, int64 offset, int64 size) + : allocation_(allocation), offset_(offset), size_(size) {} + + const BufferAllocation* allocation() const { return allocation_; } + Index index() const { return allocation_->index(); } + int64 offset() const { return offset_; } + int64 size() const { return size_; } + + bool operator==(const Slice& other) const { + return index() == other.index() && offset_ == other.offset_ && + size_ == other.size_; + } + bool operator!=(const Slice& other) const { return !(*this == other); } + bool operator<(const Slice& other) const { + if (index() != other.index()) return index() < other.index(); + if (offset_ != other.offset_) return offset_ < other.offset_; + return size_ < other.size_; + } + + // Returns true iff this slice's memory range has a non-empty intersection + // with the other slice's memory range. + bool OverlapsWith(const Slice& other) const { + const int64 end = offset_ + size_; + const int64 other_end = other.offset_ + other.size_; + return index() == other.index() && offset_ < other_end && + end > other.offset_; + } + + struct Hasher { + size_t operator()(Slice s) const; + }; + + string ToString() const; + + private: + const BufferAllocation* allocation_ = nullptr; + int64 offset_ = 0; + int64 size_ = 0; + }; + + // GetSlice returns the Slice of contiguous memory that holds the value + // described by the given 'buffer'. + // REQUIRES: 'buffer' must be assigned to this allocation. + Slice GetSlice(const LogicalBuffer& buffer) const; string ToString() const; + BufferAllocationProto ToProto() const; // Whether the buffer is a parameter to or live out of the entry computation. bool IsInputOrOutput() const { @@ -137,6 +196,21 @@ class BufferAllocation { } private: + // Only BufferAssigner and BufferAssignment can modify BufferAllocation. + friend class BufferAssigner; + friend class BufferAssignment; + + // Adds a LogicalBuffer to the set assigned to this buffer. + void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); + + void set_entry_computation_parameter(int64 parameter_number) { + is_entry_computation_parameter_ = true; + parameter_number_ = parameter_number; + } + void set_maybe_live_out(bool value) { maybe_live_out_ = value; } + void set_index(Index index) { index_ = index; } + void set_size(int64 size) { size_ = size; } + // The index of the allocation in the BufferAssignment. 
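The `OverlapsWith` predicate above treats slices of different allocations as never overlapping; within one allocation it intersects the half-open byte ranges [offset, offset + size). Below is a tiny self-contained check of the same arithmetic (`RangesOverlap` is a hypothetical helper written for illustration, not part of this change); the class's private members resume immediately after it.

```c++
#include <cassert>
#include <cstdint>

// Mirrors Slice::OverlapsWith for two slices of the same allocation:
// half-open ranges intersect iff each one starts before the other ends.
bool RangesOverlap(int64_t offset_a, int64_t size_a, int64_t offset_b,
                   int64_t size_b) {
  return offset_a < offset_b + size_b && offset_b < offset_a + size_a;
}

int main() {
  assert(!RangesOverlap(0, 64, 64, 32));  // adjacent slices do not overlap
  assert(RangesOverlap(0, 64, 32, 16));   // [32, 48) lies inside [0, 64)
  assert(!RangesOverlap(0, 0, 0, 64));    // an empty slice overlaps nothing
  return 0;
}
```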
Index index_; @@ -149,6 +223,9 @@ class BufferAllocation { // Whether this buffer is usable by more than one logical buffer. bool is_reusable_; + // Color of the allocation. + LogicalBuffer::Color color_; + // Whether this allocation holds an entry computation parameter. Entry // computation parameters are special because they have lifetimes which may // outlast the computation. @@ -164,12 +241,14 @@ class BufferAllocation { // might not actually escape. bool maybe_live_out_ = false; - // The set of buffers assigned to this allocation. - std::vector<const LogicalBuffer*> assigned_buffers_; + // Mapping from the set of buffers assigned to this allocation to their + // logical offsets and sizes. + tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_; }; -// Add stream operator for nicer output of CHECK/RET_CHECK failures. +// Add stream operators for nicer output of CHECK/RET_CHECK failures. std::ostream& operator<<(std::ostream& out, const BufferAllocation& s); +std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s); // This class encapsulates an assignment of the LogicalBuffers in an XLA // module to a set of BufferAllocations. @@ -180,6 +259,11 @@ class BufferAssignment { return allocations_; } + // Returns the total size of the allocation holding all temporary buffers. + int64 temp_allocation_total_size() const { + return temp_allocation_total_size_; + } + // Returns whether the given buffer has been assigned an allocation. bool HasAllocation(const LogicalBuffer& buffer) const; @@ -192,29 +276,28 @@ class BufferAssignment { // with the given index. const BufferAllocation& GetAllocation(BufferAllocation::Index index) const; - // Builds and returns a vector containing the allocations which might contain - // the subvalue at the given index of given instruction. - std::set<BufferAllocation> GetAllocations(const HloInstruction* instruction, - const ShapeIndex& index) const; + // Builds and returns a vector containing the slices which might contain the + // subvalue at the given index of the given instruction. + std::set<BufferAllocation::Slice> GetAllSlices( + const HloInstruction* instruction, const ShapeIndex& index) const; // Convenience function which returns whether the top-level buffer of the // instruction (index == {}) is assigned an allocation. bool HasTopLevelAllocation(const HloInstruction* instruction) const; - // Convenience function which returns the unique buffer allocation containing - // the buffer at the given index of the given instruction. If an allocation is - // not assigned or the allocation cannot be determined at compile time then an - // error is returned. - StatusOr<const BufferAllocation*> GetUniqueAllocation( + // Convenience function which returns the unique slice containing the buffer + // at the given index of the given instruction. If a slice is not assigned or + // the slice cannot be determined at compile time then an error is returned. + StatusOr<BufferAllocation::Slice> GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const; - // Like GetUniqueAllocation but fixes the index to the top-level of the shape + // Like GetUniqueSlice but fixes the index to the top-level of the shape // (index = {}). - StatusOr<const BufferAllocation*> GetUniqueTopLevelAllocation( + StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice( const HloInstruction* instruction) const; - // Like GetUniqueTopLevelAllocation but returns the allocation for the output - // of the entry computation of the HLO module (ie, the result of the XLA + // Like GetUniqueTopLevelSlice but returns the slice for the output of the + // entry computation of the HLO module (ie, the result of the XLA // computation).
- StatusOr<const BufferAllocation*> GetUniqueTopLevelOutputAllocation() const; + StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const; // Returns the set of LogicalBuffers which may be the source of the value at the // given index and instruction. @@ -223,36 +306,75 @@ class BufferAssignment { return GetPointsToSet(instruction).element(index); } + // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}' + // share the same BufferAllocation::Slice. + // Returns false otherwise. + // REQUIRES: BufferAssignment assigned allocations to both instructions. + bool SharesSliceAtIndex(const HloInstruction* hlo_a, + const ShapeIndex& shape_index_a, + const HloInstruction* hlo_b, + const ShapeIndex& shape_index_b) const; + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); } + // Returns the BufferLiveness object used to construct this assignment. + const BufferLiveness& liveness() const { return *liveness_; } + string ToString() const; + BufferAssignmentProto ToProto() const; + + // Statistics for the assignment. Values initialized to -1 are not always + // collected; fragmentation is only collected for instructions that have a + // sequential total ordering. + struct Stats { + int64 parameter_allocation_count = 0; + int64 parameter_allocation_bytes = 0; + int64 maybe_live_out_allocation_count = 0; + int64 maybe_live_out_allocation_bytes = 0; + int64 preallocated_temp_allocation_count = 0; + int64 preallocated_temp_allocation_bytes = 0; + int64 preallocated_temp_fragmentation_bytes = -1; + int64 total_allocation_count = 0; + int64 total_allocation_bytes = 0; + int64 total_fragmentation_bytes = -1; + + string ToString() const; + }; + const Stats& GetStats() const { return stats_; } private: // Only BufferAssigner can build or modify BufferAssignments. friend class BufferAssigner; explicit BufferAssignment(const HloModule* module, - std::unique_ptr<BufferLiveness> liveness) - : module_(module), liveness_(std::move(liveness)) {} + std::unique_ptr<BufferLiveness> liveness, + int64 alignment, + LogicalBuffer::SizeFunction buffer_size) + : module_(module), + liveness_(std::move(liveness)), + alignment_(alignment), + buffer_size_(std::move(buffer_size)) {} - // Creates and returns a new BufferAllocation. Ownership is maintained - // internally. The allocation initially has only the given LogicalBuffer - // assigned to it. `is_thread_local` indicates whether this buffer needs to be - // thread-local. + // Creates and returns a new BufferAllocation, with no assigned + // LogicalBuffers. Ownership is maintained internally. + BufferAllocation* NewEmptyAllocation(int64 size, bool is_thread_local, + bool is_reusable, + LogicalBuffer::Color color); + + // Helper that calls NewEmptyAllocation and AddAssignment in one call, + // creating an allocation containing a single LogicalBuffer. BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size, bool is_thread_local, bool is_reusable); - // Adds a LogicalBuffer to the set assigned to the given allocation. If - // colocated_buffer is true, then the logical buffer is an alias of another - // buffer assigned to this allocation. - void AddAssignment(const LogicalBuffer& buffer, BufferAllocation* allocation, - bool colocated_buffer); + // Adds a LogicalBuffer to the set assigned to the given allocation. + void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, + int64 offset, int64 size); - // Returns the BufferLiveness object used to construct this assignment.
- const BufferLiveness& liveness() { return *liveness_; } + // Returns the HloModule used to construct this assignment. + const HloModule& module() const { return *module_; } // Convenience function which returns the PointsToSet for the given // instruction. Extracted from the liveness object. @@ -262,15 +384,31 @@ class BufferAssignment { BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer); BufferAllocation* GetMutableAllocation(BufferAllocation::Index index); + // Combines allocations of temporary buffers into one big BufferAllocation. + void CombineTempAllocations(); + + // Computes stats for the assignment, to be retrieved by GetStats. + Status ComputeSummaryStats(); + // The vector of buffer allocations. Indexed by BufferAllocation::Index. std::vector<BufferAllocation> allocations_; + // The total size of all temporary buffers. + int64 temp_allocation_total_size_ = 0; + // Maps Buffers to the index of the BufferAllocation which holds the buffer. - std::map<const LogicalBuffer*, BufferAllocation::Index> + tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index> allocation_index_for_buffer_; const HloModule* module_; - std::unique_ptr<BufferLiveness> liveness_; + const std::unique_ptr<BufferLiveness> liveness_; + const int64 alignment_; + + // Function which returns the buffer size for a given logical buffer (shape). + LogicalBuffer::SizeFunction buffer_size_; + + Stats stats_; + std::vector<HeapSimulatorTrace> heap_simulator_traces_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); }; @@ -280,50 +418,61 @@ class BufferAssigner { public: // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size is a function - // which returns the size of a LogicalBuffer. If hlos_to_allocate is not null - // then only instructions in this vector are considered for buffer - // assignment. If hlos_to_allocate is null then all instructions are - // considered. If 'colocate_related_buffers' is true, related LogicalBuffers - // will be colocated in the same allocation (i.e buffers for while result - // will share an allocation with buffers related to that same while - // instruction: init operand, condition/body parameter and body result). + // which returns the size of a LogicalBuffer. Alignment is the minimum + // alignment of any buffer. allow_input_output_aliasing specifies whether + // input buffers are allowed to be reused as output buffers by the client code. static StatusOr<std::unique_ptr<BufferAssignment>> Run( const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, bool colocate_related_buffers, - const std::vector<const HloInstruction*>* hlos_to_allocate = nullptr); - - // Overload of Run which uses ShapeUtil::ByteSizeOf to determine buffer size - // and assigns buffers to all HLO instructions in the module. - static StatusOr<std::unique_ptr<BufferAssignment>> Run( - const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, - int64 pointer_size); + LogicalBuffer::SizeFunction buffer_size, int64 alignment, + bool allow_input_output_aliasing = false, + TuplePointsToAnalysis::Colorer colorer = + TuplePointsToAnalysis::DefaultColorer()); private: - explicit BufferAssigner(LogicalBuffer::SizeFunction buffer_size, - bool colocate_related_buffers) - : buffer_size_(std::move(buffer_size)), - colocate_related_buffers_(colocate_related_buffers) {} + BufferAssigner(int64 alignment, bool allow_input_output_aliasing, + TuplePointsToAnalysis::Colorer colorer) + : alignment_(alignment), + allow_input_output_aliasing_(allow_input_output_aliasing), + colorer_(colorer) {} virtual ~BufferAssigner() = default; // Create a buffer assignment.
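Before the private CreateAssignment machinery below, here is a hedged sketch of how a caller might invoke the public `Run` entry point declared above. It mirrors the test helpers later in this diff; the wrapper name and the byte-size lambda are illustrative, and `DependencyHloOrdering` is just one possible ordering.

```c++
#include <memory>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative wrapper: assign buffers using pointer-sized byte counts,
// a 1-byte alignment constraint, and the default colorer.
std::unique_ptr<xla::BufferAssignment> AssignWithDefaults(
    xla::HloModule* module) {
  return xla::BufferAssigner::Run(
             module, xla::MakeUnique<xla::DependencyHloOrdering>(module),
             [](const xla::LogicalBuffer& buffer) {
               return xla::ShapeUtil::ByteSizeOf(buffer.shape(),
                                                 sizeof(void*));
             },
             /*alignment=*/1)
      .ConsumeValueOrDie();
}
```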
StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment( const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, - const std::vector<const HloInstruction*>* hlos_to_allocate = nullptr); + LogicalBuffer::SizeFunction buffer_size); // Assigns buffers to the instructions in the given computation. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to - // true. If hlos_to_allocate is not null it indicates which HLOs to include in - // buffer assignment. If null, all instructions in the computation are - // included. - tensorflow::Status AssignBuffersForComputation( + // true. + Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet<const HloInstruction*>* hlos_to_allocate, const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers, const tensorflow::gtl::FlatSet<BufferAllocation::Index>& colocated_allocations, + tensorflow::gtl::FlatMap<const HloComputation*, + tensorflow::gtl::FlatSet<const LogicalBuffer*>>* + buffers_to_assign_sequentially, BufferAssignment* assignment); + // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming + // the HLO instructions will be executed in the sequential order given by + // assignment->liveness().hlo_ordering().SequentialOrder. If + // 'run_whole_module_heap_simulation' is true, the heap simulation will be run + // assuming all global computations are sequentially ordered. + Status AssignBuffersWithSequentialOrdering( + const tensorflow::gtl::FlatMap< + const HloComputation*, + tensorflow::gtl::FlatSet<const LogicalBuffer*>>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment); + + // Uses the results of the heap simulator to create a single allocation, with + // LogicalBuffers packed to specific offsets. + void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, + BufferAssignment* assignment, + LogicalBuffer::Color color); + // Tries to assign the given instruction to the given buffer. Returns whether // the assignment was successful. bool MaybeAssignBuffer(BufferAllocation* allocation, @@ -340,7 +489,8 @@ class BufferAssigner { // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' // which should be colocated in the same buffer allocation. void BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size, std::vector<ColocatedBufferSet>* colocated_buffer_sets); // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the @@ -357,13 +507,32 @@ class BufferAssigner { const std::vector<const LogicalBuffer*>& colocated_set, std::vector<ColocatedBufferSet>* colocated_buffer_sets); - const HloModule* module_; + // Conceptually the same as AddSetToColocatedBufferSets, but specific to the + // colocated buffers for while instructions. + void AddWhileSetToColocatedBufferSets( + const std::vector<const LogicalBuffer*>& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size, + std::vector<ColocatedBufferSet>* colocated_buffer_sets); - // Function which returns the buffer size for a given shape. - LogicalBuffer::SizeFunction buffer_size_; + // Split a set of buffers into several sets, each of which contains buffers + // colored with the same color. + tensorflow::gtl::FlatMap<LogicalBuffer::Color, + tensorflow::gtl::FlatSet<const LogicalBuffer*>, + LogicalBuffer::Color::Hasher> + SplitBuffersByColor( + const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers); - // Indicates whether related buffers should share the same buffer allocation.
- const bool colocate_related_buffers_; + // Minimum alignment of any buffer. + int64 alignment_; + + // If true, buffer assignment assumes that input parameter buffers and output + // buffers can be shared if their sizes match. + bool allow_input_output_aliasing_; + + // Functor used to assign colors to newly allocated logical buffers. + TuplePointsToAnalysis::Colorer colorer_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index b8841c35f68..892f67a8812 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -18,16 +18,23 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -74,6 +81,24 @@ class BufferAssignmentTest : public HloTestBase { BufferAssignmentTest() : computation_tracker_() {} ~BufferAssignmentTest() override {} + std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, + int64 alignment = 1) { + return BufferAssigner::Run( + module, MakeUnique<DependencyHloOrdering>(module), + backend_->compiler()->BufferSizeBytesFunction(), alignment) + .ConsumeValueOrDie(); + } + + std::unique_ptr<BufferAssignment> RunColoredBufferAssignment( + HloModule* module, TuplePointsToAnalysis::Colorer colorer, + int64 alignment = 1) { + return BufferAssigner::Run(module, + MakeUnique<DependencyHloOrdering>(module), + backend_->compiler()->BufferSizeBytesFunction(), + alignment, false, std::move(colorer)) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map.
std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -145,7 +170,7 @@ class BufferAssignmentTest : public HloTestBase { const BufferAssignment& buffers, HloInstruction* hlo) { LOG(INFO) << "Checking input: " << hlo->ToString(); const BufferAllocation& buffer = - *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation(); EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number()); return buffer; } @@ -163,11 +188,13 @@ class BufferAssignmentTest : public HloTestBase { const BufferAllocation& GetAllocation(const BufferAssignment& buffers, const HloInstruction* hlo, const ShapeIndex& index) { - return *buffers.GetUniqueAllocation(hlo, index).ConsumeValueOrDie(); + return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation(); } const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers, const HloInstruction* hlo) { - return *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie(); + return *buffers.GetUniqueTopLevelSlice(hlo) + .ConsumeValueOrDie() + .allocation(); } // Verifies that all instructions in the given instruction list except @@ -195,32 +222,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Returns true if the buffers assigned to instructions in "a" are distinct - // from the buffers assigned to those in "b" (ie, intersection is empty). - bool BuffersDistinct(const std::vector<const HloInstruction*>& a, - const std::vector<const HloInstruction*>& b, - const BufferAssignment& assignment) { - std::set<BufferAllocation::Index> a_buffers; - for (const HloInstruction* instruction : a) { - if (assignment.HasTopLevelAllocation(instruction)) { - a_buffers.insert(assignment.GetUniqueTopLevelAllocation(instruction) - .ConsumeValueOrDie() - ->index()); - } - } - - for (const HloInstruction* instruction : b) { - if (assignment.HasTopLevelAllocation(instruction)) { - if (a_buffers.count(assignment.GetUniqueTopLevelAllocation(instruction) - .ConsumeValueOrDie() - ->index())) { - return false; - } - } - } - return true; - } - // Computation tracker for nested computations. ComputationTracker computation_tracker_; @@ -235,12 +236,28 @@ class BufferAssignmentTest : public HloTestBase { Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_}); }; -namespace { -std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module) { - return BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module), - /*pointer_size=*/sizeof(void*)) - .ConsumeValueOrDie(); -} +// Returns true if the buffers assigned to instructions in "a" are distinct +// from the buffers assigned to those in "b" (i.e., intersection is empty). +static bool BuffersDistinct(const std::vector<const HloInstruction*>& a, + const std::vector<const HloInstruction*>& b, + const BufferAssignment& assignment) { + std::set<BufferAllocation::Slice> a_slices; + for (const HloInstruction* instruction : a) { + if (assignment.HasTopLevelAllocation(instruction)) { + a_slices.insert( + assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie()); + } + } + + for (const HloInstruction* instruction : b) { + if (assignment.HasTopLevelAllocation(instruction)) { + if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { + return false; + } + } + } + return true; } // Tests a computation consisting of a single scalar constant node.
@@ -248,7 +265,7 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignment(module.get()); @@ -266,7 +283,7 @@ TEST_F(BufferAssignmentTest, BufferForConst) { LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignment(module.get()); @@ -284,7 +301,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignment(module.get()); @@ -311,7 +328,7 @@ TEST_F(BufferAssignmentTest, Basic) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignment(module.get()); @@ -331,7 +348,113 @@ TEST_F(BufferAssignmentTest, Basic) { // The add node can reuse the mul node's buffer. const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); - EXPECT_EQ(add_buffer.index(), add_buffer.index()); + EXPECT_EQ(add_buffer.index(), mul_buffer.index()); + + // The sub node has a valid output buffer assigned. + GetAssignedOutputAllocation(*buffers, sub); +} + +TEST_F(BufferAssignmentTest, BasicUniquelyColored) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + // The output of each op is colored with a different color, so we can not + // share anything. + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunColoredBufferAssignment( + module.get(), + [](const HloInstruction* instruction, const ShapeIndex& index) { + static int64 serial = 0; + return LogicalBuffer::Color(serial++); + }); + + // Distinct input buffers were assigned for parameters. 
+ BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index()); + EXPECT_NE(param0_buffer.index(), param1_buffer.index()); + + // The mul node has a valid buffer assigned, doesn't share with input. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + + // The add node can not reuse the mul node's buffer due to coloring. + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_NE(add_buffer.index(), mul_buffer.index()); + + // The sub node has a valid output buffer assigned. + GetAssignedOutputAllocation(*buffers, sub); +} + +TEST_F(BufferAssignmentTest, BasicPartiallyColored) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + // The output of the mul and the add have the color 1, and the other buffers + // have the color 0, which allows the mul and add to share buffers. + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunColoredBufferAssignment( + module.get(), + [](const HloInstruction* instruction, const ShapeIndex& index) { + return (instruction->opcode() == HloOpcode::kAdd || + instruction->opcode() == HloOpcode::kMultiply) + ? LogicalBuffer::Color(1) + : LogicalBuffer::Color(0); + }); + + // Distinct input buffers were assigned for parameters. + BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index()); + EXPECT_NE(param0_buffer.index(), param1_buffer.index()); + + // The mul node has a valid buffer assigned, doesn't share with input. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + + // The add node can reuse the mul node's buffer. + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_EQ(add_buffer.index(), mul_buffer.index()); // The sub node has a valid output buffer assigned. 
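The two colored tests above pin down the contract: logical buffers may share an allocation only when their colors match, so a colorer is simply a function from an instruction and shape index to a LogicalBuffer::Color. As one more hypothetical example in the same style (not part of this change), a colorer that isolates the root instruction's buffer from all reuse might look like the sketch below; the sub-node check mentioned above resumes right after it.

```c++
// Illustrative only: gives the root instruction a private color so no
// other buffer may share its allocation, and leaves every other buffer
// in the default color 0.
auto isolate_root_colorer = [](const xla::HloInstruction* instruction,
                               const xla::ShapeIndex& index) {
  const bool is_root =
      instruction == instruction->parent()->root_instruction();
  return xla::LogicalBuffer::Color(is_root ? 1 : 0);
};
```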
GetAssignedOutputAllocation(*buffers, sub); @@ -361,7 +484,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignment(module.get()); @@ -396,7 +519,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // param0[100x10] ---> (map x+1) // // Builds the map function. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto map_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); auto inner_last = map_computation->root_instruction(); @@ -451,7 +574,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // out-of-order reductions could overwrite an element before a use.) // // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto reduce_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); @@ -502,7 +625,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // const4[f32[4]] --- tuple --- while[condition, body] // // Builds the nested condition and body. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto condition_computation = module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); auto body_computation = @@ -553,15 +676,14 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Check that buffer for each subshape of 'while_op' shares allocation with // corresponding buffer from while body computation at same index. - TF_CHECK_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( while_op->shape(), [this, &buffers, while_op, body_root](const Shape& /*subshape*/, const ShapeIndex& index) { auto while_op_allocation = GetAllocation(*buffers, while_op, index); auto body_root_allocation = GetAllocation(*buffers, body_root, index); EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index()); - return Status::OK(); - })); + }); // Log size information for inspection. 
LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size() @@ -583,7 +705,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto neg = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -608,11 +730,11 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1})); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -640,12 +762,12 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1})); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -677,12 +799,12 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { auto tuple_element = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0)); auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1})); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -703,30 +825,29 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { // // param ---> (negate) ---> (slice) ---> (broadcast) // - // The negate should *not* share a buffer with broadcast. + // Neither negate nor slice may share a buffer with broadcast. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32vec100_, "param0")); // Negate output is 100 elements. auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + // Slice output is 10 elements. auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1})); // Broadcast output is 40 elements. 
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); - // The instructions should not share buffers. + // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), GetTopLevelAllocation(*assignment, negate)); EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), GetTopLevelAllocation(*assignment, slice)); - EXPECT_NE(GetTopLevelAllocation(*assignment, negate), - GetTopLevelAllocation(*assignment, slice)); } TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { @@ -745,12 +866,12 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1})); // Broadcast output is 40 elements. auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -773,38 +894,37 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { // // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple) // - // The negate should *not* share a buffer with broadcast. + // Neither negate nor slice may share a buffer with broadcast. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32vec100_, "param0")); // Negate output is 100 elements. auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); + // Slice output is 10 elements. auto slice = builder.AddInstruction( - HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10})); + HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1})); // Broadcast output is 40 elements. auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); - // The instructions should not share buffers. + // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), GetTopLevelAllocation(*assignment, negate)); EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), GetTopLevelAllocation(*assignment, slice)); - EXPECT_NE(GetTopLevelAllocation(*assignment, negate), - GetTopLevelAllocation(*assignment, slice)); } TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Verify that buffers for embedded computations are properly marked as // thread-local and that embedded parameters are not marked as // is_entry_computation_parameter. 
- auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto vec_shape = ShapeUtil::MakeShape(F32, {42}); auto scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -849,8 +969,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { EXPECT_FALSE(map_root_alloc.maybe_live_out()); EXPECT_TRUE(map_root_alloc.is_thread_local()); - // Allocations for the call computation should not be thread-local and not - // live-out. + // Allocations for the call computation should not be thread-local. auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); EXPECT_FALSE(call_param_alloc.maybe_live_out()); @@ -858,7 +977,6 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root); EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter()); - EXPECT_FALSE(call_root_alloc.maybe_live_out()); EXPECT_FALSE(call_root_alloc.is_thread_local()); // Entry computation allocations can be marked liveout and @@ -883,7 +1001,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { ShapeUtil::MakeShape(S32, {42})}), "param0")); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -893,7 +1011,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { // Verify each buffer allocation is marked as an entry computation parameter // and is liveout. - TF_CHECK_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( tuple_param->shape(), [this, &assignment, tuple_param](const Shape& /*subshape*/, const ShapeIndex& index) { @@ -901,8 +1019,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { EXPECT_TRUE(allocation.is_entry_computation_parameter()); EXPECT_EQ(0, allocation.parameter_number()); EXPECT_TRUE(allocation.maybe_live_out()); - return Status::OK(); - })); + }); } TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { @@ -919,7 +1036,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -962,7 +1079,7 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), LiteralUtil::CreateR0(1).get()}))); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -976,7 +1093,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), ShapeUtil::MakeShape(S32, {101})}), /*operands=*/{}, /*custom_call_target=*/"foo_function")); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -991,7 +1108,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { TEST_F(BufferAssignmentTest, TupleCallAsOutput) { // Test a computation which returns a tuple call value. 
- auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1024,6 +1141,75 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { GetTopLevelAllocation(*assignment, sub_param)); } +TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { + // Test a chain of calls with tuple output. The chain looks like: + // A: call(B, tuple(param)) + // B: call(C, param) + // C: call(D, param) + // D: param + auto module = CreateNewModule(); + auto elem_shape = f32vec4_; + auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); + + auto d_builder = HloComputation::Builder(TestName() + "_d"); + auto d_param = d_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "d_param")); + auto d_computation = d_builder.Build(); + + auto c_builder = HloComputation::Builder(TestName() + "_c"); + auto c_param = c_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "c_param")); + auto c_call = c_builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get())); + auto c_computation = c_builder.Build(); + + auto b_builder = HloComputation::Builder(TestName() + "_b"); + auto b_param = b_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "b_param")); + auto b_call = b_builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get())); + auto b_computation = b_builder.Build(); + + auto a_builder = HloComputation::Builder(TestName()); + auto a_param = a_builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape, "param")); + auto a_tuple = + a_builder.AddInstruction(HloInstruction::CreateTuple({a_param})); + auto a_call = a_builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get())); + auto a_computation = a_builder.Build(); + + // Add the computations in an order that doesn't match the dependency + // post-order, to shake out more possible bugs. + module->AddEmbeddedComputation(std::move(d_computation)); + module->AddEmbeddedComputation(std::move(c_computation)); + module->AddEntryComputation(std::move(a_computation)); + module->AddEmbeddedComputation(std::move(b_computation)); + + auto assignment = RunBufferAssignment(module.get()); + + // Buffers for call are co-located with the sub-computations. + EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), + GetAllocation(*assignment, b_call, /*index=*/{})); + EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}), + GetAllocation(*assignment, c_call, /*index=*/{})); + EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}), + GetAllocation(*assignment, d_param, /*index=*/{})); + EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}), + GetAllocation(*assignment, b_call, /*index=*/{0})); + EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}), + GetAllocation(*assignment, c_call, /*index=*/{0})); + EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}), + GetAllocation(*assignment, d_param, /*index=*/{0})); + // The parameters aren't aliased with anything. 
+ EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment)); + EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment)); + EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment)); + EXPECT_TRUE(BuffersDistinct({b_param}, {c_param}, *assignment)); + EXPECT_TRUE(BuffersDistinct({b_param}, {d_param}, *assignment)); + EXPECT_TRUE(BuffersDistinct({c_param}, {d_param}, *assignment)); +} + TEST_F(BufferAssignmentTest, BitcastAsOutput) { // Test a computation which returns a bitcast value. auto builder = HloComputation::Builder(TestName()); @@ -1032,7 +1218,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto bitcast = builder.AddInstruction( HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -1058,7 +1244,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -1066,19 +1252,20 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { // buffer and receives its own allocation. auto select_alloc = GetTopLevelAllocation(*assignment, select); EXPECT_EQ(1, select_alloc.assigned_buffers().size()); - EXPECT_EQ(select, select_alloc.assigned_buffers()[0]->instruction()); + EXPECT_EQ(select, + select_alloc.assigned_buffers().begin()->first->instruction()); // The buffer for the tuple element of the select is forwarded from one its - // operands which cannot be determined statically. Therefore its allocation - // should include the allocations of both of the elements in the parameters. - auto element_allocations = assignment->GetAllocations(select, /*index=*/{0}); - EXPECT_EQ(2, element_allocations.size()); - EXPECT_MATCH(testing::SetToVec(element_allocations), - testing::UnorderedMatcher( - *assignment->GetUniqueAllocation(tuple_param0, /*index=*/{0}) - .ConsumeValueOrDie(), - *assignment->GetUniqueAllocation(tuple_param1, /*index=*/{0}) - .ConsumeValueOrDie())); + // operands which cannot be determined statically. Therefore its slices + // should include the slices of both of the elements in the parameters. 
+ auto element_slices = assignment->GetAllSlices(select, /*index=*/{0}); + EXPECT_EQ(2, element_slices.size()); + EXPECT_THAT(element_slices, + ::testing::UnorderedElementsAre( + assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) + .ConsumeValueOrDie(), + assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) + .ConsumeValueOrDie())); } // TODO(b/34669761): Remove this test when buffers are allowed to share @@ -1095,7 +1282,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto copy = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape, HloOpcode::kCopy, tuple_element)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); auto assignment = RunBufferAssignment(module.get()); @@ -1106,6 +1293,330 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { GetTopLevelAllocation(*assignment, copy)); } -} // namespace +TEST_F(BufferAssignmentTest, OneTempAllocation) { + // Test a computation that requires multiple temp buffers, and ensure they are + // combined into a single allocation. + auto builder = HloComputation::Builder(TestName()); + Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4}); + Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4}); + Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4}); + // There should be separate temp buffers for dot_ab and dot_bc. + auto param_a = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape_2x3, "param_a")); + auto param_b = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape_3x4, "param_b")); + auto param_c = builder.AddInstruction( + HloInstruction::CreateParameter(2, shape_4x4, "param_c")); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( + shape_2x4, HloOpcode::kDot, param_a, param_b)); + auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( + shape_3x4, HloOpcode::kDot, param_b, param_c)); + builder.AddInstruction( + HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); + + // Run buffer assignment with alignment=1. + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); + + // There are 5 allocations: 3 parameters, 1 output, and 1 temp. + EXPECT_EQ(5, assignment->Allocations().size()); + + // Ensure the temp buffers for dot_ab and dot_bc share a single allocation, + // and each occupies different slices of that allocation. + BufferAllocation::Slice slice_ab = + assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); + BufferAllocation::Slice slice_bc = + assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); + EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation()); + EXPECT_NE(slice_ab, slice_bc); + EXPECT_EQ(32, slice_ab.size()); + EXPECT_EQ(48, slice_bc.size()); + EXPECT_EQ(80, slice_ab.allocation()->size()); + EXPECT_EQ(80, slice_bc.allocation()->size()); + + // Re-run buffer assignment with alignment=64. 
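Before the re-run below, it is worth spelling out the arithmetic the test expects: the two temp slices are 32 and 48 bytes, so with 64-byte alignment whichever slice is placed first sits at offset 0 and the other is rounded up to offset 64, giving a combined allocation of either 64 + 48 = 112 or 64 + 32 = 96 bytes. A self-contained sketch of that rounding (`RoundUpTo` here is a hypothetical helper written out for illustration):

```c++
#include <cassert>
#include <cstdint>

// Round n up to the next multiple of a positive alignment.
constexpr int64_t RoundUpTo(int64_t n, int64_t align) {
  return ((n + align - 1) / align) * align;
}

int main() {
  // 32-byte slice first: the 48-byte slice starts at 64, total 112.
  assert(RoundUpTo(32, 64) == 64);
  assert(64 + 48 == 112);
  // 48-byte slice first: the 32-byte slice starts at 64, total 96.
  assert(RoundUpTo(48, 64) == 64);
  assert(64 + 32 == 96);
  return 0;
}
```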
+ assignment = RunBufferAssignment(module.get(), /*alignment=*/64); + EXPECT_EQ(5, assignment->Allocations().size()); + slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); + slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); + EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation()); + EXPECT_NE(slice_ab, slice_bc); + EXPECT_EQ(32, slice_ab.size()); + EXPECT_EQ(48, slice_bc.size()); + // Ensure the offsets and allocation size account for the alignment, without + // assuming which buffer gets assigned first. + if (slice_ab.offset() == 0) { + EXPECT_EQ(64, slice_bc.offset()); + EXPECT_EQ(64 + 48, slice_ab.allocation()->size()); + EXPECT_EQ(64 + 48, slice_bc.allocation()->size()); + } else { + EXPECT_EQ(64, slice_ab.offset()); + EXPECT_EQ(0, slice_bc.offset()); + EXPECT_EQ(64 + 32, slice_ab.allocation()->size()); + EXPECT_EQ(64 + 32, slice_bc.allocation()->size()); + } +} + +class WhileBufferAssignmentTest : public HloTestBase { + protected: + std::unique_ptr BuildWhileConditionComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ten = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + return builder.Build(); + } + + std::unique_ptr BuildWhileBodyComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto input = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0)); + auto weights = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + auto output = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kMultiply, input, weights)); + builder.AddInstruction( + HloInstruction::CreateTuple({input, weights, output})); + return builder.Build(); + } + + std::unique_ptr RunBufferAssignment(HloModule* module, + int64 alignment = 1) { + auto sequence = + CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + return BufferAssigner::Run( + module, MakeUnique(module, sequence), + ByteSizeOf, alignment) + .ConsumeValueOrDie(); + } + + static int64 ByteSizeOf(const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); + } + + Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); + Shape loop_state_shape_ = + ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); +}; + +static void RunCopyInsertion(HloModule* module) { + CopyInsertion copy_insertion; + EXPECT_IS_OK(copy_insertion.Run(module).status()); +} + +TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "weights1")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); 
+
+class WhileBufferAssignmentTest : public HloTestBase {
+ protected:
+  std::unique_ptr<HloComputation> BuildWhileConditionComputation(
+      const string& name) {
+    auto builder = HloComputation::Builder(name);
+    builder.AddInstruction(
+        HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+    auto zero = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+    auto ten = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
+    builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
+    return builder.Build();
+  }
+
+  std::unique_ptr<HloComputation> BuildWhileBodyComputation(
+      const string& name) {
+    auto builder = HloComputation::Builder(name);
+    auto loop_state = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+    auto input = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
+    auto weights = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+    auto output = builder.AddInstruction(HloInstruction::CreateBinary(
+        data_shape_, HloOpcode::kMultiply, input, weights));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({input, weights, output}));
+    return builder.Build();
+  }
+
+  std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
+                                                        int64 alignment = 1) {
+    auto sequence =
+        CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
+    return BufferAssigner::Run(
+               module, MakeUnique<SequentialHloOrdering>(module, sequence),
+               ByteSizeOf, alignment)
+        .ConsumeValueOrDie();
+  }
+
+  static int64 ByteSizeOf(const LogicalBuffer& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
+  }
+
+  Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
+  Shape loop_state_shape_ =
+      ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
+};
+
+static void RunCopyInsertion(HloModule* module) {
+  CopyInsertion copy_insertion;
+  EXPECT_IS_OK(copy_insertion.Run(module).status());
+}
+
+TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
+  auto module = MakeUnique<HloModule>(TestName());
+  auto builder = HloComputation::Builder("entry");
+
+  auto input0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, data_shape_, "input0"));
+  auto weights0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+  auto weights1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, data_shape_, "weights1"));
+
+  auto zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+  auto output0 = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+  auto output1 = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+  auto cond0 =
+      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+  auto body0 =
+      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+  auto tuple0 = builder.AddInstruction(
+      HloInstruction::CreateTuple({input0, weights0, output0}));
+  auto while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+
+  auto cond1 =
+      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+  auto body1 =
+      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+  auto input1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
+  auto tuple1 = builder.AddInstruction(
+      HloInstruction::CreateTuple({input1, weights1, output1}));
+  auto while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
+
+  module->AddEntryComputation(builder.Build());
+  RunCopyInsertion(module.get());
+  auto assignment = RunBufferAssignment(module.get());
+
+  // Verify 'input0' and read-only use while0{0} alias.
+  EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
+  // Verify 'weights0' and read-only use while0{1} alias.
+  EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
+  // Verify 'while0{2}' and read-only use while1{0} alias.
+  EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
+  // Verify 'weights1' and read-only use while1{1} alias.
+  EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
+}
+
+TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
+  auto module = MakeUnique<HloModule>(TestName());
+  auto builder = HloComputation::Builder("entry");
+
+  auto input0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, data_shape_, "input0"));
+  auto weights0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+
+  auto zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+  auto output0 = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+  auto output1 = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+  auto cond0 =
+      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+  auto body0 =
+      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+  auto tuple0 = builder.AddInstruction(
+      HloInstruction::CreateTuple({input0, weights0, output0}));
+  auto while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+
+  auto cond1 =
+      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+  auto body1 =
+      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+  auto tuple1 = builder.AddInstruction(
+      HloInstruction::CreateTuple({input0, weights0, output1}));
+  auto while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
+
+  module->AddEntryComputation(builder.Build());
+  RunCopyInsertion(module.get());
+  auto assignment = RunBufferAssignment(module.get());
+
+  // while0 and while1 buffers should be completely aligned.
+  EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
+  EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
+  EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
+            assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie());
+}
+
+TEST_F(BufferAssignmentTest, TwoCalls) {
+  auto module = MakeUnique<HloModule>(TestName());
+  Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
+  HloComputation* sub_computation;
+  {
+    auto builder = HloComputation::Builder(TestName() + "_sub_comp");
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, r0f32, "param"));
+    auto constant1 = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+    auto add = builder.AddInstruction(
+        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
+    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
+  }
+  auto builder = HloComputation::Builder(TestName());
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto constant3 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+  auto call1 = builder.AddInstruction(
+      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
+  auto call2 = builder.AddInstruction(
+      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
+  auto add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
+  auto add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
+  module->AddEntryComputation(builder.Build(add2));
+
+  {
+    FlattenCallGraph flatten;
+    TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get()));
+    EXPECT_TRUE(result);
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  }
+
+  RunCopyInsertion(module.get());
+  auto assignment = RunBufferAssignment(module.get());
+
+  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
+}
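A note on why the test above flattens the call graph first: with both `kCall` instructions targeting the same `sub_computation`, a per-computation buffer assignment would have to reuse one set of buffers for both calls. Here is a standalone sketch of the flattening idea (illustrative types only, not the actual `FlattenCallGraph` pass): clone a shared callee so that each call site owns a distinct computation.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Computation {
  std::string name;
};

struct Call {
  Computation* callee;
};

int main() {
  auto shared = std::make_unique<Computation>(Computation{"sub"});
  std::vector<Call> calls = {{shared.get()}, {shared.get()}};

  // "Flatten": give every call site after the first its own clone.
  std::vector<std::unique_ptr<Computation>> clones;
  for (size_t i = 1; i < calls.size(); ++i) {
    clones.push_back(std::make_unique<Computation>(*calls[i].callee));
    calls[i].callee = clones.back().get();
  }
  // Each call site now targets a distinct computation, so a buffer assigner
  // can safely give the two calls non-overlapping buffers.
  assert(calls[0].callee != calls[1].callee);
  return 0;
}
```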
+
+// Test buffer assignment for while nodes with multiple uses.
+// TODO(b/37245345): Fix buffer assignment for this case.
+TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
+  auto module = MakeUnique<HloModule>(TestName());
+  auto builder = HloComputation::Builder(TestName());
+
+  auto input0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, data_shape_, "input0"));
+  auto weights0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+
+  auto zero = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+  auto output0 = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+  auto cond0 =
+      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+  auto body0 =
+      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+  auto tuple0 = builder.AddInstruction(
+      HloInstruction::CreateTuple({input0, weights0, output0}));
+  auto while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+  auto while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0));
+
+  auto get0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
+  auto get1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1));
+  module->AddEntryComputation(builder.Build());
+
+  RunCopyInsertion(module.get());
+
+  {
+    FlattenCallGraph flatten;
+    TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get()));
+    EXPECT_TRUE(result);
+  }
+
+  auto assignment = RunBufferAssignment(module.get());
+
+  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
+}
+
+}  // namespace
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index c788c643069..1b14c26340f 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -17,11 +17,11 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 
-#include <set>
 #include <utility>
 
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/liveness_util.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -37,15 +37,17 @@ namespace xla {
 
 /* static */
 StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
-    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering) {
+    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+    TuplePointsToAnalysis::Colorer colorer) {
   std::unique_ptr<BufferLiveness> liveness(
-      new BufferLiveness(module, std::move(hlo_ordering)));
+      new BufferLiveness(module, std::move(hlo_ordering), std::move(colorer)));
   TF_RETURN_IF_ERROR(liveness->Analyze());
   return std::move(liveness);
 }
 
 tensorflow::Status BufferLiveness::Analyze() {
-  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
+  TF_ASSIGN_OR_RETURN(points_to_analysis_,
+                      TuplePointsToAnalysis::Run(module_, colorer_));
   for (auto& computation : module_->computations()) {
     // Gather all instructions whose buffers might alias other instructions into
    // the set aliased_buffers_.
This includes those contained as a tuple
@@ -61,11 +63,9 @@ tensorflow::Status BufferLiveness::Analyze() {
     }
 
     if (computation.get() == module_->entry_computation()) {
-      for (const LogicalBuffer* live_out_buffer :
-           points_to_analysis_->GetPointsToSet(computation->root_instruction())
-               .CreateFlattenedSet()) {
-        maybe_live_out_buffers_.insert(live_out_buffer);
-      }
+      const HloInstruction* root = computation->root_instruction();
+      maybe_live_out_buffers_ =
+          points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet();
     }
   }
 
@@ -92,19 +92,6 @@ string BufferLiveness::ToString() const {
   return tensorflow::str_util::Join(pieces, "\n");
 }
 
-// Returns false if 'user' cannot possibly use the buffer at 'index' in
-// 'operand'. Returns true otherwise.
-// Precondition: 'operand' is an operand of 'user'.
-bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
-                           HloInstruction* user) {
-  if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
-    // GetTupleElement instructions only access the top-level buffer of their
-    // operand.
-    return false;
-  }
-  return true;
-}
-
 bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
                                                 const LogicalBuffer& b) const {
   TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
@@ -117,7 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
   // Every user of 'a' must be a predecessor of 'b' or 'b' itself.
   for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
     for (auto user : alias.instruction()->users()) {
-      if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user)) {
+      if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user,
+                                  points_to_analysis())) {
         continue;
       }
       if (user != b.instruction() &&
@@ -127,23 +115,44 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
     }
   }
 
-  // If 'b' is a user of 'a' then the buffers interfere if b is not an
-  // elementwise operation emitting the same shape/layout as 'a'.
+  // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction'
+  // and 'b.instruction' emit the same shape/layout, and 'b.instruction' meets
+  // the qualifications specified in CanShareOperandBufferWithUser.
   for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
-    if (alias.instruction()->users().count(b.instruction()) > 0 &&
-        (!ShapeUtil::Equal(alias.instruction()->shape(),
-                           b.instruction()->shape()) ||
-         !b.instruction()->IsElementwise())) {
+    if (b.instruction()->IsUserOf(alias.instruction()) &&
+        !CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
+                                       b.instruction(), b.index(),
+                                       points_to_analysis())) {
       return false;
     }
   }
   return true;
 }
 
+namespace {
+bool IsEntryParameter(const HloInstruction* instruction) {
+  const HloComputation* computation = instruction->parent();
+  return instruction->opcode() == HloOpcode::kParameter &&
+         computation == computation->parent()->entry_computation();
+}
+}  // namespace
+
 bool BufferLiveness::MayInterfere(const LogicalBuffer& a,
                                   const LogicalBuffer& b) const {
-  return (!live_range_strictly_before(a, b) &&
-          !live_range_strictly_before(b, a));
+  // Entry parameters live at the entry of the execution, thus always interfere
+  // with all other instructions executing before them in the ordering.
+  const HloInstruction* a_instruction = a.instruction();
+  const HloInstruction* b_instruction = b.instruction();
+  if (IsEntryParameter(a_instruction) &&
+      hlo_ordering_->ExecutesBefore(b_instruction, a_instruction)) {
+    return true;
+  }
+  if (IsEntryParameter(b_instruction) &&
+      hlo_ordering_->ExecutesBefore(a_instruction, b_instruction)) {
+    return true;
+  }
+  // Buffers without disjoint liveness may interfere.
+  return !live_range_strictly_before(a, b) &&
+         !live_range_strictly_before(b, a);
 }
 
 bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const {
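The `MayInterfere` change above reduces to interval reasoning: two buffers may interfere unless one live range ends strictly before the other begins, and entry parameters behave as if defined at position 0 of the ordering. A standalone sketch (illustrative only, with hand-picked positions mirroring the sequential-ordering test later in this diff):

```cpp
#include <cassert>

// Live range over positions in a sequential ordering, half-open [begin, end).
struct LiveRange {
  int begin;
  int end;
};

bool MayInterfere(const LiveRange& a, const LiveRange& b) {
  // Interfere unless one range ends before the other begins.
  return a.begin < b.end && b.begin < a.end;
}

int main() {
  // Ordering: param0, negate, param1, exp, add (as in the sequential test).
  LiveRange param0{0, 2};  // entry parameters are live from position 0
  LiveRange param1{0, 4};
  LiveRange negate{1, 5}, exp{3, 5};
  assert(MayInterfere(param0, param1));  // both pinned to the entry
  assert(MayInterfere(param1, negate));  // param1 outlives negate's definition
  assert(!MayInterfere(param0, exp));    // param0 dies before exp is defined
  return 0;
}
```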
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
index b9e7a2a28db..9bb2564a831 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -39,7 +39,9 @@ class BufferLiveness {
   // Constructs a buffer liveness object for the given module assuming the given
   // HLO instruction ordering.
   static StatusOr<std::unique_ptr<BufferLiveness>> Run(
-      const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering);
+      const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+      TuplePointsToAnalysis::Colorer colorer =
+          TuplePointsToAnalysis::DefaultColorer());
 
   // Returns true if the live range of the buffer containing the output of 'a'
   // may overlap with the live range of the buffer of 'b'. If instruction 'a'
@@ -51,17 +53,29 @@ class BufferLiveness {
   // the entry computation.
   bool MaybeLiveOut(const LogicalBuffer& buffer) const;
 
+  // Returns the complete set of buffers that may be live out of the module.
+  const tensorflow::gtl::FlatSet<const LogicalBuffer*>& maybe_live_out_buffers()
+      const {
+    return maybe_live_out_buffers_;
+  }
+
   // Returns the underlying points-to analysis used for this liveness analysis.
   const TuplePointsToAnalysis& points_to_analysis() const {
     return *points_to_analysis_;
   }
 
+  // Returns the underlying hlo ordering used for this liveness analysis.
+  const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
+
   string ToString() const;
 
  private:
   explicit BufferLiveness(const HloModule* module,
-                          std::unique_ptr<HloOrdering> hlo_ordering)
-      : module_(module), hlo_ordering_(std::move(hlo_ordering)) {}
+                          std::unique_ptr<HloOrdering> hlo_ordering,
+                          TuplePointsToAnalysis::Colorer colorer)
+      : module_(module),
+        hlo_ordering_(std::move(hlo_ordering)),
+        colorer_(colorer) {}
 
   // Perform buffer liveness analysis. This method must be called prior to
   // MayInterfere or MaybeLiveOut.
@@ -84,6 +98,8 @@ class BufferLiveness {
   tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers_;
 
   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+
+  TuplePointsToAnalysis::Colorer colorer_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 1ca5768dbe1..fda44ff4d2d 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -92,6 +92,12 @@ class BufferLivenessTest : public HloTestBase {
                      GetBuffer(liveness, instruction, /*index=*/{}));
   }
 
+  std::unique_ptr<HloComputation> BuildDummyComputation() {
+    auto builder = HloComputation::Builder(TestName() + "_dummy");
+    builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+    return builder.Build();
+  }
+
   const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
 };
 
@@ -110,7 +116,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
   auto log = builder.AddInstruction(
       HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   auto liveness =
@@ -118,12 +124,17 @@
                           MakeUnique<DependencyHloOrdering>(module.get()))
           .ConsumeValueOrDie();
 
-  // No buffers should interfere.
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
+
+  // No buffers should interfere.
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log));
-  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp));
 
   // A buffer should interfere with itself.
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp));
@@ -135,22 +146,73 @@
   EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log));
 }
 
-TEST_F(BufferLivenessTest, NonElementwiseOperand) {
-  // A chain of operations with one elementwise and one non-elementwise. The
-  // elementwise op should not interfere with its operand, while the
-  // non-elementwise op should interfere.
+TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
+  // Two entry params, which interfere with each other.
 //
-//   param --> negate -> reverse
+//   param0 --> negate ---------------\
+//   param1 --> exp --------------> add
   auto builder = HloComputation::Builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, vec_, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, vec_, "param1"));
+  auto negate = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
+
+  auto module = CreateNewModule();
+  HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  sequence.insert({entry, {param0, negate, param1, exp, add}});
+  auto liveness = BufferLiveness::Run(
+                      module.get(),
+                      MakeUnique<SequentialHloOrdering>(module.get(), sequence))
+                      .ConsumeValueOrDie();
+
+  // Entry parameters interfere as if they are defined simultaneously at
+  // the very beginning.
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add));
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0));
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add));
+
+  // Negate and exp still interfere.
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
+
+  // But {negate, add} and {exp, add} don't interfere.
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
+}
+
+TEST_F(BufferLivenessTest, NonElementwiseOperand) {
+  // A chain of operations with two elementwise and one non-elementwise. The
+  // elementwise op should not interfere with its operand, while the
+  // non-elementwise op should interfere. Entry params always interfere.
+  //
+  // param --> exp -> negate -> reverse
+  //
   auto builder = HloComputation::Builder(TestName());
   auto param =
       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
   auto negate = builder.AddInstruction(
-      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
   auto reverse =
       builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   auto liveness =
@@ -158,10 +220,14 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
                           MakeUnique<DependencyHloOrdering>(module.get()))
           .ConsumeValueOrDie();
 
-  // No buffers should interfere.
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse));
+
+  // Negate is elementwise, so doesn't interfere with its operand.
+  // Reverse is non-elementwise, so does interfere with its operand.
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse));
-  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
 }
 
 TEST_F(BufferLivenessTest, OverlappedBuffers) {
@@ -180,7 +246,7 @@
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   auto liveness =
@@ -190,8 +256,15 @@
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp));
-  EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
+
+  // Negate and exp interfere with each other, but not with add.
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
 }
 
 TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
@@ -204,8 +277,7 @@
   // Sequential order:
   //     param, negate, exp, add
   //
-  // Liveness is identical to the DependencyHloOrdering except that 'param' and
-  // exp no longer interfere.
+  // Liveness is identical to the DependencyHloOrdering.
   auto builder = HloComputation::Builder(TestName());
   auto param =
       builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
@@ -216,7 +288,7 @@
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   SequentialHloOrdering::HloModuleSequence module_sequence;
@@ -229,8 +301,15 @@
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
-  EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
+
+  // Negate and exp interfere with each other, but not with add.
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
+  EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
+  EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
 }
 
 TEST_F(BufferLivenessTest, TupleLiveOut) {
@@ -251,7 +330,7 @@
   auto outer_tuple =
       builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   auto liveness =
@@ -271,7 +350,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
 
 TEST_F(BufferLivenessTest, EmbeddedComputation) {
   // Test MaybeLiveOut and MayInterfere for embedded computation.
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
 
   auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
   auto embedded_param = embedded_builder.AddInstruction(
@@ -328,7 +407,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
   builder.AddInstruction(HloInstruction::CreateGetTupleElement(
       inner_tuple0->shape(), tuple_constant, 0));
 
-  auto module = MakeUnique<HloModule>(TestName());
+  auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
   auto liveness =
@@ -391,8 +470,9 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
   auto tuple_root =
       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
 
-  auto module = MakeUnique<HloModule>(TestName());
-  module->AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(BuildDummyComputation());
+  module->AddEmbeddedComputation(builder.Build());
 
   auto liveness =
       BufferLiveness::Run(module.get(),
@@ -451,8 +531,9 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
   auto tuple_root =
       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
 
-  auto module = MakeUnique<HloModule>(TestName());
-  module->AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(BuildDummyComputation());
+  module->AddEmbeddedComputation(builder.Build());
 
   auto liveness =
       BufferLiveness::Run(module.get(),
@@ -482,6 +563,229 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
 }
 
+class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
+ protected:
+  // Builds and runs a computation (see test case computation graphs below).
+  // Runs BufferLiveness on this computation.
+  // Returns whether buffer interference is detected between tuple-shaped
+  // parameter and root instructions at tuple element 1.
+  bool Run(const bool update_uses_tuple_element1,
+           const bool fuse_gte0 = false) {
+    auto builder = HloComputation::Builder(TestName());
+    // Create param0 Tuple.
+    Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+    Shape update_shape = ShapeUtil::MakeShape(F32, {3});
+    auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+        0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
+
+    auto gte0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
+
+    auto gte1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
+
+    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+    HloInstruction* slice = nullptr;
+    if (update_uses_tuple_element1) {
+      // Create a slice instruction as an additional user of 'gte1'.
+      slice = builder.AddInstruction(
+          HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
+      update = builder.AddInstruction(HloInstruction::CreateBinary(
+          update_shape, HloOpcode::kAdd, update, slice));
+    }
+    // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
+    auto starts = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2})));
+    auto dynamic_update_slice =
+        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+            data_shape, gte1, update, starts));
+    // Create output tuple.
+    auto tuple_root = builder.AddInstruction(
+        HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+    // Build module and get reference to entry computation.
+    auto module = CreateNewModule();
+    module->AddEntryComputation(BuildDummyComputation());
+    auto* computation = module->AddEmbeddedComputation(builder.Build());
+    // Create fusion instruction based on number of tuple element 1 users.
+    if (update_uses_tuple_element1) {
+      computation->CreateFusionInstruction(
+          {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
+          HloInstruction::FusionKind::kLoop);
+    } else {
+      computation->CreateFusionInstruction(
+          {dynamic_update_slice, starts, update, gte1},
+          HloInstruction::FusionKind::kLoop);
+    }
+    // Create fusion instruction for tuple element 0 (if requested).
+    if (fuse_gte0) {
+      computation->CreateFusionInstruction({gte0},
+                                           HloInstruction::FusionKind::kLoop);
+    }
+
+    // Run BufferLiveness on 'module'.
+    auto liveness =
+        BufferLiveness::Run(module.get(),
+                            MakeUnique<DependencyHloOrdering>(module.get()))
+            .ConsumeValueOrDie();
+    // Return whether or not buffer interference is detected between
+    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+    return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
+  }
+};
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
+// do not overlap with the following computation:
+//
+//         Param0
+//        /      \
+//     GTE(0)   Fusion -----------> FusionParam
+//       |        |                     |
+//       |        |                  GTE(1) Const Const
+//       |        |                     \     |    /
+//       |        |              DynamicUpdateSlice  // fused root
+//        \      /
+//          Tuple  // computation root
+//
+TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
+  EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
+}
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
+// 'fusion1') do not overlap in the presence of another fusion instruction
+// (which is a user of 'param0' at a different tuple index).
+// BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
+// (because Fusion0 only uses Param0 at index {0}).
+//
+//            Param0
+//           /      \
+//  FusionParam <----- Fusion0   Fusion1 ------> FusionParam
+//      |                 |        |                 |
+//   GTE(0)               |        |              GTE(1) Const Const
+//                        |        |                 \     |    /
+//                         \      /           DynamicUpdateSlice
+//                          Tuple
+//
+TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
+  EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
+}
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
+// do overlap because GTE(1) has two users:
+// 1) DynamicUpdateSlice at operand 0.
+// 2) Slice at operand 0.
+//
+//         Param0
+//        /      \   Const
+//       /        \    /
+//     GTE(0)   Fusion -----------> FusionParam FusionParam
+//       |        |                     |           |
+//       |        |                  GTE(1)        /
+//       |        |                     |  \      /
+//       |        |                     |  Slice /
+//       |        |                     |     \ /
+//       |        |                     |     Add    Const
+//       |        |                     |      |       |
+//       |        |              DynamicUpdateSlice  // fused root
+//        \      /
+//          Tuple  // computation root
+//
+TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
+  EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
+}
+
+class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
+ protected:
+  // Builds and runs a computation (see test case computation graphs below).
+  // Runs BufferLiveness on this computation.
+  // Returns whether buffer interference is detected between tuple-shaped
+  // parameter and root instructions at tuple element 1.
+  bool Run(const bool tuple_element1_has_two_uses) {
+    auto builder = HloComputation::Builder(TestName());
+    // Create param0 Tuple.
+    Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+    Shape update_shape = ShapeUtil::MakeShape(F32, {3});
+    auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+        0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
+
+    auto gte0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
+
+    auto gte1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
+
+    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+
+    if (tuple_element1_has_two_uses) {
+      // Add 'gte0' and 'gte1' to create another user of 'gte1'.
+      gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
+          data_shape, HloOpcode::kAdd, gte0, gte1));
+    }
+    // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
+    auto starts = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2})));
+    auto dynamic_update_slice =
+        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+            data_shape, gte1, update, starts));
+    // Create output tuple.
+    auto tuple_root = builder.AddInstruction(
+        HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+    // Build module and get reference to entry computation.
+    auto module = CreateNewModule();
+    module->AddEntryComputation(BuildDummyComputation());
+    module->AddEmbeddedComputation(builder.Build());
+    // Run BufferLiveness on 'module'.
+    auto liveness =
+        BufferLiveness::Run(module.get(),
+                            MakeUnique<DependencyHloOrdering>(module.get()))
+            .ConsumeValueOrDie();
+    // Return whether or not buffer interference is detected between
+    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+    return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
+  }
+};
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
+// the following computation (because DynamicUpdateSlice (at operand 0) is the
+// unique user):
+//
+//     Parameter0
+//      |       |
+//    GTE(0)  GTE(1) Const Const
+//      |        \     |    /
+//      |     DynamicUpdateSlice
+//       \      /
+//        Tuple
+//
+TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
+  EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
+}
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
+// GTE(1) has two users:
+// 1) DynamicUpdateSlice at operand 0.
+// 2) Add at operand 1.
+//
+//     Parameter0
+//      |       |
+//    GTE(0)  GTE(1)
+//      |     /   |
+//      |    /    |
+//     Add  |   Const Const
+//      |   |     |     |
+//      |  DynamicUpdateSlice
+//       \      /
+//        Tuple
+//
+TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
+  EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
+}
+
+}  // namespace
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
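The common thread in the DynamicUpdateSlice tests above is a unique-user rule: updating the operand in place is only safe when the updating instruction is the operand's sole user, since any second user still needs the original value. A standalone sketch of that check (illustrative structs only, not `CanShareOperandBufferWithUser` itself):

```cpp
#include <cassert>
#include <vector>

struct Node {};

struct Value {
  std::vector<const Node*> users;
};

// In-place update is safe only when `updater` is the operand's unique user.
bool CanUpdateInPlace(const Value& operand, const Node* updater) {
  return operand.users.size() == 1 && operand.users[0] == updater;
}

int main() {
  Node dus, add;
  Value gte1_single{{&dus}};        // only DynamicUpdateSlice reads gte1
  Value gte1_shared{{&dus, &add}};  // an Add also reads gte1
  assert(CanUpdateInPlace(gte1_single, &dus));   // NoInterference case
  assert(!CanUpdateInPlace(gte1_shared, &dus));  // WithInterference case
  return 0;
}
```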
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
new file mode 100644
index 00000000000..fa7b2a30952
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -0,0 +1,306 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+
+#include <queue>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+using ::tensorflow::strings::Appendf;
+using ::tensorflow::strings::StrCat;
+
+string CallContextToString(CallContext context) {
+  switch (context) {
+    case CallContext::kNone:
+      return "kNone";
+    case CallContext::kSequential:
+      return "kSequential";
+    case CallContext::kParallel:
+      return "kParallel";
+    case CallContext::kBoth:
+      return "kBoth";
+  }
+}
+
+std::ostream& operator<<(std::ostream& out, const CallContext& context) {
+  out << CallContextToString(context);
+  return out;
+}
+
+CallContext GetInstructionCallContext(const HloInstruction* instruction) {
+  switch (instruction->opcode()) {
+    case HloOpcode::kCall:
+    case HloOpcode::kWhile:
+      return CallContext::kSequential;
+    case HloOpcode::kMap:
+    case HloOpcode::kReduce:
+    case HloOpcode::kReduceWindow:
+    case HloOpcode::kSelectAndScatter:
+    case HloOpcode::kFusion:
+      return CallContext::kParallel;
+    default:
+      return CallContext::kNone;
+  }
+}
+
+string CallSite::ToString() const {
+  return StrCat(instruction()->name(), " calls in context ",
+                CallContextToString(context()), ": ",
+                tensorflow::str_util::Join(
+                    called_computations(), ", ",
+                    [](string* out, const HloComputation* computation) {
+                      out->append(computation->name());
+                    }));
+}
+
+CallGraphNode::CallGraphNode(HloComputation* computation)
+    : computation_(computation) {}
+
+const CallSite* CallGraphNode::GetCallSite(
+    const HloInstruction* instruction) const {
+  auto it = callsite_instructions_.find(instruction);
+  if (it == callsite_instructions_.end()) {
+    return nullptr;
+  }
+  return &callsites_[it->second];
+}
+
+void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
+  caller_callsites_.push_back(caller_callsite);
+  HloComputation* caller = caller_callsite.instruction()->parent();
+  if (!ContainsKey(caller_set_, caller)) {
+    callers_.push_back(caller);
+    caller_set_.insert(caller);
+  }
+}
+
+void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
+  CHECK_EQ(instruction->parent(), computation());
+  const CallContext context = GetInstructionCallContext(instruction);
+  if (!instruction->called_computations().empty()) {
+    CHECK(context == CallContext::kSequential ||
+          context == CallContext::kParallel);
+    callsite_instructions_.insert({instruction, callsites_.size()});
+    callsites_.push_back(
+        CallSite(instruction, instruction->called_computations(), context));
+    // Update callee computations to include any new computations called by this
+    // instruction.
+    for (auto* callee : callsites_.back().called_computations()) {
+      if (!ContainsKey(callee_set_, callee)) {
+        callees_.push_back(callee);
+        callee_set_.insert(callee);
+      }
+    }
+  }
+}
+
+CallGraph::CallGraph(const HloModule* module) : module_(module) {}
+
+const CallGraphNode& CallGraph::GetNode(
+    const HloComputation* computation) const {
+  auto it = node_indices_.find(computation);
+  CHECK(it != node_indices_.end());
+  return nodes_[it->second];
+}
+
+CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
+  auto it = node_indices_.find(computation);
+  CHECK(it != node_indices_.end());
+  return nodes_[it->second];
+}
+
+namespace {
+
+// Returns the call context of a computation which is called from contexts 'a'
+// and 'b'.
+CallContext UnionContexts(CallContext a, CallContext b) {
+  if (a == CallContext::kNone) {
+    return b;
+  } else if (b == CallContext::kNone) {
+    return a;
+  } else if (a == b) {
+    return a;
+  } else {
+    // Contexts are different and neither is kNone, i.e., one is kSequential
+    // and the other is kParallel.
+    return CallContext::kBoth;
+  }
+}
+
+}  // namespace
+
+void CallGraph::SetCallContexts() {
+  std::queue<CallGraphNode*> worklist;
+
+  // Initialize worklist with all roots of the call graph (computations without
+  // callers).
+  for (const std::unique_ptr<HloComputation>& computation :
+       module_->computations()) {
+    CallGraphNode& node = GetNode(computation.get());
+    if (node.callers().empty()) {
+      node.set_context(CallContext::kSequential);
+      worklist.push(&node);
+    }
+  }
+
+  while (!worklist.empty()) {
+    CallGraphNode* node = worklist.front();
+    worklist.pop();
+
+    for (const CallSite& callsite : node->callsites()) {
+      for (const HloComputation* callee : callsite.called_computations()) {
+        CallGraphNode& callee_node = GetNode(callee);
+
+        // Update context of callee computation based on the callsite and its
+        // current context.
+        CallContext context_to_add;
+        if (callsite.context() == CallContext::kParallel) {
+          context_to_add = CallContext::kParallel;
+        } else {
+          CHECK_EQ(callsite.context(), CallContext::kSequential);
+          context_to_add = node->context();
+        }
+        CallContext new_context =
+            UnionContexts(context_to_add, callee_node.context());
+
+        if (new_context != callee_node.context()) {
+          // Context of computation has been changed so add node to worklist.
+          callee_node.set_context(new_context);
+          worklist.push(&callee_node);
+        }
+      }
+    }
+  }
+
+  // No node should have a kNone calling context.
+  for (const std::unique_ptr<HloComputation>& computation :
+       module_->computations()) {
+    CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone);
+  }
+}
+
+/* static */
+std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
+  // Constructor for CallGraph is private so MakeUnique can't be used.
+  auto call_graph = WrapUnique(new CallGraph(module));
+
+  VLOG(2) << "Building call graph for:";
+  XLA_VLOG_LINES(2, module->ToString());
+
+  // Construct nodes of the call graph and populate the callsites.
+  for (const std::unique_ptr<HloComputation>& computation :
+       module->computations()) {
+    auto it_added = call_graph->node_indices_.insert(
+        {computation.get(), call_graph->nodes_.size()});
+    // All computations should be unique, so the computation should not already
+    // exist in the map.
+    CHECK(it_added.second);
+    call_graph->nodes_.emplace_back(computation.get());
+
+    // Add all callsites in this computation.
+    for (const std::unique_ptr<HloInstruction>& instruction :
+         computation->instructions()) {
+      call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get());
+    }
+  }
+
+  // Add caller callsites to each node.
+  for (const std::unique_ptr<HloComputation>& computation :
+       module->computations()) {
+    for (const CallSite& callsite :
+         call_graph->GetNode(computation.get()).callsites()) {
+      for (auto* callee : callsite.called_computations()) {
+        // Add caller callsites.
+        call_graph->GetNode(callee).AddCallerCallSite(callsite);
+      }
+    }
+  }
+
+  call_graph->SetCallContexts();
+  XLA_VLOG_LINES(1, call_graph->ToString());
+
+  return call_graph;
+}
+
+Status CallGraph::VisitNodesInternal(
+    const VisitorFunction& visitor_func, const CallGraphNode& node,
+    tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
+  auto pair = visited->insert(&node);
+  if (!pair.second) {
+    // Node was not inserted. Node has already been visited.
+    return Status::OK();
+  }
+
+  for (const HloComputation* computation : node.callees()) {
+    TF_RETURN_IF_ERROR(
+        VisitNodesInternal(visitor_func, GetNode(computation), visited));
+  }
+
+  return visitor_func(node);
+}
+
+Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
+                             bool visit_unreachable_nodes) const {
+  tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
+  if (visit_unreachable_nodes) {
+    // Traverse from all roots in the call graph.
+    for (const CallGraphNode& node : nodes()) {
+      if (node.callers().empty()) {
+        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
+      }
+    }
+  } else {
+    // Traverse only from the entry computation.
+    TF_RETURN_IF_ERROR(VisitNodesInternal(
+        visitor_func, GetNode(module_->entry_computation()), &visited));
+  }
+
+  return Status::OK();
+}
+
+string CallGraph::ToString() const {
+  string out;
+  Appendf(&out, "Call graph for module %s:\n", module_->name().c_str());
+  for (const CallGraphNode& node : nodes()) {
+    Appendf(&out, "Computation %s:\n", node.computation()->name().c_str());
+    Appendf(&out, "  calls:\n");
+    for (const HloComputation* callee : node.callees()) {
+      Appendf(&out, "    %s\n", callee->name().c_str());
+    }
+    Appendf(&out, "  called by:\n");
+    for (const HloComputation* caller : node.callers()) {
+      Appendf(&out, "    %s\n", caller->name().c_str());
+    }
+    Appendf(&out, "  callsites:\n");
+    for (const CallSite& callsite : node.callsites()) {
+      Appendf(&out, "    %s\n", callsite.ToString().c_str());
+    }
+  }
+  return out;
+}
+
+}  // namespace xla
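Before the header, a quick note on the join rule that `SetCallContexts` relies on: `UnionContexts` treats `kNone` as the identity, equal contexts join to themselves, and mixing sequential with parallel yields `kBoth`. A standalone sketch with the same case analysis (illustrative enum only):

```cpp
#include <cassert>

enum class Ctx { kNone, kSequential, kParallel, kBoth };

Ctx Union(Ctx a, Ctx b) {
  if (a == Ctx::kNone) return b;  // kNone is the identity element
  if (b == Ctx::kNone) return a;
  if (a == b) return a;           // equal contexts join to themselves
  return Ctx::kBoth;              // one sequential, one parallel
}

int main() {
  assert(Union(Ctx::kNone, Ctx::kParallel) == Ctx::kParallel);
  assert(Union(Ctx::kSequential, Ctx::kSequential) == Ctx::kSequential);
  assert(Union(Ctx::kSequential, Ctx::kParallel) == Ctx::kBoth);
  return 0;
}
```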
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
new file mode 100644
index 00000000000..7f9990f06d4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Call graph for an HLO module.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+
+#include <ostream>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// The context in which a computation is called by another computation.
+enum class CallContext {
+  // In a parallel context the computation is applied to each element of the
+  // array argument(s). kMap and kReduce instructions call computations in
+  // parallel context.
+  kParallel,
+
+  // In a sequential context the computation is applied to the entire argument
+  // shape(s). kCall and kWhile (body and condition) call computations in
+  // sequential context.
+  kSequential,
+
+  // A computation is called from both a parallel and sequential context.
+  kBoth,
+
+  // During call graph construction kNone is used to indicate that the context
+  // has not been determined. This is the top value for the context
+  // lattice. After construction, no call sites or call graph nodes should have
+  // this value.
+  kNone
+};
+
+string CallContextToString(CallContext context);
+std::ostream& operator<<(std::ostream& out, const CallContext& context);
+
+CallContext GetInstructionCallContext(const HloInstruction* instruction);
+
+// Represents an HLO instruction which calls one or more computations.
+class CallSite {
+ public:
+  CallSite(HloInstruction* instruction,
+           const std::vector<HloComputation*>& called_computations,
+           CallContext context)
+      : instruction_(CHECK_NOTNULL(instruction)),
+        called_computations_(called_computations),
+        context_(context) {}
+
+  // Returns the instruction associated with this call site.
+  HloInstruction* instruction() const { return instruction_; }
+
+  // Returns the computations called at this call site.
+  const std::vector<HloComputation*>& called_computations() const {
+    return called_computations_;
+  }
+
+  // Returns the context in which computations are called at this call site.
+  CallContext context() const { return context_; }
+
+  string ToString() const;
+
+ private:
+  // The calling instruction.
+  HloInstruction* instruction_;
+
+  // The computations called by this callsite.
+  const std::vector<HloComputation*> called_computations_;
+
+  // The context in which the computations are called.
+  const CallContext context_;
+};
+
+// A node in the call graph representing an HLO computation.
+class CallGraphNode {
+ public:
+  CallGraphNode(HloComputation* computation);
+
+  // Returns the computation represented by this call graph node.
+  HloComputation* computation() const { return computation_; }
+
+  // Returns the call sites in this computation. These are the instructions in
+  // this computation which call other computations.
+  const std::vector<CallSite>& callsites() const { return callsites_; }
+
+  // Returns the callsite associated with the given instruction. If this
+  // instruction calls no computations, nullptr is returned.
+  // Prerequisite: instruction is in the computation associated with this call
+  // graph node.
+  const CallSite* GetCallSite(const HloInstruction* instruction) const;
+
+  // Returns the computations called by this computation.
+  const std::vector<HloComputation*>& callees() const { return callees_; }
+
+  // Returns the call sites in other computations which call this computation.
+  const std::vector<CallSite>& caller_callsites() const {
+    return caller_callsites_;
+  }
+
+  // Returns the computations which call this computation.
+  const std::vector<HloComputation*>& callers() const { return callers_; }
+
+  // Returns the context in which this computation is called.
+  CallContext context() const { return context_; }
+
+  string ToString() const;
+
+ private:
+  // Only CallGraph can modify CallGraphNode.
+  friend class CallGraph;
+
+  // Sets the context in which this computation is called.
+  void set_context(CallContext value) { context_ = value; }
+
+  // Adds a callsite which calls this computation. Updates callers to include
+  // the calling computation.
+  void AddCallerCallSite(const CallSite& caller_callsite);
+
+  // If the instruction calls any computations, adds a call site for this
+  // instruction to the call graph node. If the instruction calls no
+  // computations then no call site is added.
+  void AddCallSiteForInstruction(HloInstruction* instruction);
+
+  // Computation represented by this call graph node.
+  HloComputation* computation_;
+
+  // The computations called by this computation. The vector is used for a
+  // stable ordering and the set enables fast membership testing.
+  std::vector<HloComputation*> callees_;
+  tensorflow::gtl::FlatSet<HloComputation*> callee_set_;
+
+  // The computations which call this computation. The vector is used for a
+  // stable ordering and the set enables fast membership testing.
+  std::vector<HloComputation*> callers_;
+  tensorflow::gtl::FlatSet<HloComputation*> caller_set_;
+
+  // The call sites in this computation.
+  std::vector<CallSite> callsites_;
+
+  // The map from instruction to index in callsites_ for looking up the
+  // callsite (if any) associated with a particular instruction in this
+  // computation.
+  tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_;
+
+  // The call sites in other computations which call this computation.
+  std::vector<CallSite> caller_callsites_;
+
+  // The context in which this computation is called.
+  CallContext context_ = CallContext::kNone;
+};
+
+// The call graph for an HLO module. The graph includes a node for each
+// computation in the module.
+class CallGraph {
+ public:
+  using VisitorFunction = std::function<Status(const CallGraphNode&)>;
+
+  // Builds and returns a call graph for the given HLO module.
+  static std::unique_ptr<CallGraph> Build(const HloModule* module);
+
+  // Returns the node associated with the given computation.
+  const CallGraphNode& GetNode(const HloComputation* computation) const;
+  CallGraphNode& GetNode(const HloComputation* computation);
+
+  // Returns the vector of all nodes in the call graph.
+  const std::vector<CallGraphNode>& nodes() const { return nodes_; }
+
+  // Calls the given function on each node in the call graph. Nodes are visited
+  // in post order (callees before callers). If visit_unreachable_nodes is true
+  // then all nodes in the call graph are visited. Otherwise only those nodes
+  // reachable from the entry computation are visited.
+  Status VisitNodes(const VisitorFunction& visitor_func,
+                    bool visit_unreachable_nodes = true) const;
+
+  string ToString() const;
+
+ private:
+  CallGraph(const HloModule* module);
+
+  // Sets the call contexts for every node in the graph.
+  void SetCallContexts();
+
+  // Helper method for VisitNodes(). Traverses the call graph from 'node' in
+  // DFS post order (callee before caller) calling visitor_func on each node.
+  // Adds nodes to 'visited' as each node is visited. Skips nodes already in
+  // 'visited'.
+  Status VisitNodesInternal(
+      const VisitorFunction& visitor_func, const CallGraphNode& node,
+      tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
+
+  // The HLO module represented by this call graph.
+  const HloModule* module_ = nullptr;
+
+  // Vector of all nodes in the call graph.
+  std::vector<CallGraphNode> nodes_;
+
+  // Map from HLO computation to the index of the corresponding call graph node
+  // in nodes_.
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
new file mode 100644
index 00000000000..e276473c90a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -0,0 +1,391 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using ::testing::UnorderedElementsAre;
+
+class CallGraphTest : public HloTestBase {
+ protected:
+  // Build and return a trivial computation taking and returning a scalar.
+  std::unique_ptr<HloComputation> MakeScalarComputation() {
+    HloComputation::Builder builder(TestName() + ".ScalarComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    builder.AddInstruction(
+        HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and maps (kMap) the
+  // given computation to the value 'callsites' number of times.
+  std::unique_ptr<HloComputation> MakeMappingComputation(
+      HloComputation* map_computation, int64 callsites) {
+    HloComputation::Builder builder(TestName() + ".MappingComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* last_value = param0;
+    for (int64 i = 0; i < callsites; ++i) {
+      last_value = builder.AddInstruction(HloInstruction::CreateMap(
+          kScalarShape, {last_value}, map_computation));
+    }
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and calls (kCall) the
+  // given computation with value 'callsites' number of times.
+  std::unique_ptr<HloComputation> MakeCallingComputation(
+      HloComputation* callee_computation, int64 callsites,
+      const string& suffix = ".CallingComputation") {
+    HloComputation::Builder builder(TestName() + suffix);
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* last_value = param0;
+    for (int64 i = 0; i < callsites; ++i) {
+      last_value = builder.AddInstruction(HloInstruction::CreateCall(
+          kScalarShape, {last_value}, callee_computation));
+    }
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and returns a PRED
+  // value.
+  std::unique_ptr<HloComputation> MakeConditionComputation() {
+    HloComputation::Builder builder(TestName() + ".ConditionComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* zero = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+    builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
+    return builder.Build();
+  }
+
+  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(CallGraphTest, SingletonComputation) {
+  // Test the call graph of a module with a single computation.
+  auto module = CreateNewModule();
+  HloComputation* computation =
+      module->AddEntryComputation(MakeScalarComputation());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(1, call_graph->nodes().size());
+  const CallGraphNode& node = call_graph->GetNode(computation);
+  EXPECT_EQ(computation, node.computation());
+  EXPECT_TRUE(node.callsites().empty());
+  EXPECT_TRUE(node.callees().empty());
+  EXPECT_TRUE(node.caller_callsites().empty());
+  EXPECT_TRUE(node.callers().empty());
+  EXPECT_EQ(CallContext::kSequential, node.context());
+}
+
+TEST_F(CallGraphTest, UnreachableComputation) {
+  // Test the call graph of a module with an entry computation and an
+  // unreachable computation.
+  auto module = CreateNewModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(MakeScalarComputation());
+  HloComputation* unreachable_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(2, call_graph->nodes().size());
+
+  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+  EXPECT_EQ(entry_computation, entry_node.computation());
+  EXPECT_EQ(CallContext::kSequential, entry_node.context());
+
+  const CallGraphNode& unreachable_node =
+      call_graph->GetNode(unreachable_computation);
+  EXPECT_EQ(unreachable_computation, unreachable_node.computation());
+  EXPECT_EQ(CallContext::kSequential, unreachable_node.context());
+}
+
+TEST_F(CallGraphTest, ParallelComputation) {
+  // Test a call graph of a module with an entry computation which calls
+  // another computation in a parallel context via kMap.
+TEST_F(CallGraphTest, ParallelComputation) {
+  // Test a call graph of a module with an entry computation which calls
+  // another computation in a parallel context via kMap.
+  auto module = CreateNewModule();
+  HloComputation* map_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+  HloComputation* entry_computation = module->AddEntryComputation(
+      MakeMappingComputation(map_computation, /*callsites=*/5));
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(2, call_graph->nodes().size());
+
+  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+  EXPECT_EQ(entry_computation, entry_node.computation());
+  EXPECT_EQ(CallContext::kSequential, entry_node.context());
+  EXPECT_EQ(5, entry_node.callsites().size());
+  EXPECT_EQ(1, entry_node.callees().size());
+  EXPECT_TRUE(entry_node.caller_callsites().empty());
+  EXPECT_TRUE(entry_node.callers().empty());
+
+  const CallGraphNode& map_node = call_graph->GetNode(map_computation);
+  EXPECT_EQ(map_computation, map_node.computation());
+  EXPECT_EQ(CallContext::kParallel, map_node.context());
+  EXPECT_TRUE(map_node.callsites().empty());
+  EXPECT_TRUE(map_node.callees().empty());
+  EXPECT_EQ(5, map_node.caller_callsites().size());
+  EXPECT_EQ(1, map_node.callers().size());
+}
+
+TEST_F(CallGraphTest, SequentialComputations) {
+  // Test a call graph of a module with an entry computation which calls
+  // another computation in a sequential context via kCall.
+  auto module = CreateNewModule();
+  HloComputation* called_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+  HloComputation* entry_computation = module->AddEntryComputation(
+      MakeCallingComputation(called_computation, /*callsites=*/3));
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(2, call_graph->nodes().size());
+
+  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+  EXPECT_EQ(entry_computation, entry_node.computation());
+  EXPECT_EQ(CallContext::kSequential, entry_node.context());
+  EXPECT_EQ(3, entry_node.callsites().size());
+  EXPECT_EQ(1, entry_node.callees().size());
+  EXPECT_TRUE(entry_node.caller_callsites().empty());
+  EXPECT_TRUE(entry_node.callers().empty());
+
+  const CallGraphNode& called_node = call_graph->GetNode(called_computation);
+  EXPECT_EQ(called_computation, called_node.computation());
+  EXPECT_EQ(CallContext::kSequential, called_node.context());
+  EXPECT_TRUE(called_node.callsites().empty());
+  EXPECT_TRUE(called_node.callees().empty());
+  EXPECT_EQ(3, called_node.caller_callsites().size());
+  EXPECT_EQ(1, called_node.callers().size());
+}
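The two tests above pin down the context rules one at a time: kCall (and kWhile) callsites are sequential, kMap callsites are parallel. A computation reached through both kinds of callsite is expected to report `kBoth`, as the next test verifies. One plausible way to express that merge; this helper is an assumption, not part of the diff:

```cpp
// Hypothetical merge rule for the context of a computation that is the
// target of multiple callsites; not taken from this change.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == b) {
    return a;  // Same kind of callsite; the context is unchanged.
  }
  return CallContext::kBoth;  // Sequential + parallel => kBoth.
}
```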
+TEST_F(CallGraphTest, ContextBothComputations) {
+  // Test a call graph of a module with an entry computation which calls
+  // another computation in both a parallel and sequential context.
+  auto module = CreateNewModule();
+  HloComputation* subcomputation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+  HloInstruction* call = builder.AddInstruction(
+      HloInstruction::CreateCall(kScalarShape, {param0}, subcomputation));
+  HloInstruction* map = builder.AddInstruction(
+      HloInstruction::CreateMap(kScalarShape, {call}, subcomputation));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(2, call_graph->nodes().size());
+
+  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+  EXPECT_EQ(entry_computation, entry_node.computation());
+  EXPECT_EQ(2, entry_node.callsites().size());
+
+  const CallSite& call_callsite = entry_node.callsites()[0];
+  EXPECT_EQ(call, call_callsite.instruction());
+  EXPECT_THAT(call_callsite.called_computations(),
+              UnorderedElementsAre(subcomputation));
+  EXPECT_EQ(CallContext::kSequential, call_callsite.context());
+  EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite);
+
+  const CallSite& map_callsite = entry_node.callsites()[1];
+  EXPECT_EQ(map, map_callsite.instruction());
+  EXPECT_THAT(map_callsite.called_computations(),
+              UnorderedElementsAre(subcomputation));
+  EXPECT_EQ(CallContext::kParallel, map_callsite.context());
+  EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite);
+
+  const CallGraphNode& sub_node = call_graph->GetNode(subcomputation);
+  EXPECT_EQ(CallContext::kBoth, sub_node.context());
+}
+
+TEST_F(CallGraphTest, ComplexGraph) {
+  // Test a call graph of a module with several computations called in various
+  // contexts. The call graph looks like:
+  //
+  //      entry
+  //      /  |
+  //     a   |
+  //   / | \ |
+  //  b  |  cond
+  //   \ |
+  //    c
+  //
+  // Calls are made via kCall, kWhile, and kMap instructions.
+  auto module = CreateNewModule();
+  HloComputation* cond_computation =
+      module->AddEmbeddedComputation(MakeConditionComputation());
+  HloComputation* c_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+  HloComputation* b_computation = module->AddEmbeddedComputation(
+      MakeMappingComputation(c_computation, /*callsites=*/1));
+
+  HloComputation* a_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".a");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* call = builder.AddInstruction(
+        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
+    builder.AddInstruction(HloInstruction::CreateWhile(
+        kScalarShape, cond_computation, b_computation, call));
+    a_computation = module->AddEmbeddedComputation(builder.Build());
+  }
+
+  HloComputation* entry_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".entry");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    builder.AddInstruction(HloInstruction::CreateWhile(
+        kScalarShape, cond_computation, a_computation, param0));
+    entry_computation = module->AddEntryComputation(builder.Build());
+  }
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(5, call_graph->nodes().size());
+
+  // Entry computation has one while instruction calling two computations
+  // (cond_computation and a_computation).
+  const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
+  ASSERT_EQ(1, entry_node.callsites().size());
+  const std::vector<HloComputation*>& called_computations =
+      entry_node.callsites()[0].called_computations();
+  EXPECT_THAT(called_computations,
+              UnorderedElementsAre(cond_computation, a_computation));
+  EXPECT_EQ(CallContext::kSequential, entry_node.context());
+
+  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
+  EXPECT_TRUE(c_node.callsites().empty());
+  EXPECT_THAT(c_node.callers(),
+              UnorderedElementsAre(a_computation, b_computation));
+  EXPECT_EQ(CallContext::kBoth, c_node.context());
+
+  // Visit the graph and verify nodes were visited in callee-before-caller
+  // order.
+  std::vector<const HloComputation*> visited;
+  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
+    visited.push_back(node.computation());
+    return Status::OK();
+  }));
+  EXPECT_EQ(visited.size(), 5);
+  // All values in visited should be unique.
+  EXPECT_EQ(
+      std::unordered_set<const HloComputation*>(visited.begin(), visited.end())
+          .size(),
+      5);
+
+  // Verify visitation order of some computations in the graph.
+  auto index_of = [&visited](const HloComputation* comp) {
+    auto it = std::find(visited.begin(), visited.end(), comp);
+    EXPECT_NE(it, visited.end());
+    return std::distance(visited.begin(), it);
+  };
+  EXPECT_EQ(4, index_of(entry_computation));
+  EXPECT_LT(index_of(cond_computation), index_of(a_computation));
+  EXPECT_LT(index_of(c_computation), index_of(b_computation));
+  EXPECT_LT(index_of(b_computation), index_of(a_computation));
+}
+
+TEST_F(CallGraphTest, VisitSingletonComputation) {
+  // Test the call graph visitor with a call graph with a single node.
+  auto module = CreateNewModule();
+  HloComputation* computation =
+      module->AddEntryComputation(MakeScalarComputation());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+
+  std::vector<const HloComputation*> visited;
+  TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
+    visited.push_back(node.computation());
+    return Status::OK();
+  }));
+  EXPECT_THAT(visited, UnorderedElementsAre(computation));
+}
+
+TEST_F(CallGraphTest, VisitUnreachableComputation) {
+  // Test the call graph visitor with a call graph with an unreachable node.
+  auto module = CreateNewModule();
+  HloComputation* entry_computation =
+      module->AddEntryComputation(MakeScalarComputation());
+  HloComputation* unreachable_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+
+  // Test visitation of only reachable nodes.
+  {
+    std::vector<const HloComputation*> visited;
+    TF_ASSERT_OK(call_graph->VisitNodes(
+        [&visited](const CallGraphNode& node) {
+          visited.push_back(node.computation());
+          return Status::OK();
+        },
+        /*visit_unreachable_nodes=*/false));
+    EXPECT_EQ(visited.size(), 1);
+    EXPECT_EQ(visited[0], entry_computation);
+  }
+
+  // Test visitation of all nodes (reachable and unreachable).
+  {
+    std::vector<const HloComputation*> visited;
+    TF_ASSERT_OK(call_graph->VisitNodes(
+        [&visited](const CallGraphNode& node) {
+          visited.push_back(node.computation());
+          return Status::OK();
+        },
+        /*visit_unreachable_nodes=*/true));
+    EXPECT_EQ(visited.size(), 2);
+    EXPECT_THAT(visited, UnorderedElementsAre(entry_computation,
+                                              unreachable_computation));
+  }
+}
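The ordering assertions in `ComplexGraph` (`cond` before `a`, `c` before `b`, `b` before `a`, entry last) fall out naturally if `VisitNodes` performs a depth-first traversal that emits a node only after all of its callees. The diff declares `VisitNodesInternal` but does not show its body; the following is a sketch under that post-order assumption, not the change's actual implementation.

```cpp
// Callee-before-caller emission via depth-first post-order. The visited set
// mirrors the FlatSet parameter declared for VisitNodesInternal above and
// prevents re-emitting computations shared by several callers.
Status VisitPostOrder(const CallGraph& call_graph, const CallGraphNode& node,
                      tensorflow::gtl::FlatSet<const CallGraphNode*>* visited,
                      const CallGraph::VisitorFunction& visitor_func) {
  if (!visited->insert(&node).second) {
    return Status::OK();  // Already emitted via another caller.
  }
  for (const HloComputation* callee : node.callees()) {
    TF_RETURN_IF_ERROR(VisitPostOrder(call_graph, call_graph.GetNode(callee),
                                      visited, visitor_func));
  }
  return visitor_func(node);  // All callees emitted; now emit this node.
}
```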
+TEST_F(CallGraphTest, VisitWithError) {
+  // Test that the call graph visitor properly propagates errors.
+  auto module = CreateNewModule();
+  module->AddEntryComputation(MakeScalarComputation());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+
+  Status status = call_graph->VisitNodes(
+      [](const CallGraphNode&) { return InternalError("Visitation failed"); });
+
+  ASSERT_FALSE(status.ok());
+  ASSERT_EQ(status.code(), tensorflow::error::INTERNAL);
+  ASSERT_THAT(status.error_message(),
+              ::testing::HasSubstr("Visitation failed"));
+}
+
+}  // namespace
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
new file mode 100644
index 00000000000..0d1a439724a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
+  ServiceOptions default_options;
+  default_options.set_platform(platform);
+  return NewService(default_options);
+}
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(const ServiceOptions& options) {
+  perftools::gputools::Platform* platform = options.platform();
+  if (platform == nullptr) {
+    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+  }
+
+  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
+                      CreateComputeConstantBackend());
+  std::unique_ptr<CompileOnlyService> service(
+      new CompileOnlyService(compiler, std::move(compute_constant_backend)));
+  return std::move(service);
+}
+
+CompileOnlyService::CompileOnlyService(
+    Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
+    : Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
+      compiler_(compiler) {
+  runs_in_client_process_ = true;
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+CompileOnlyService::CompileAheadOfTime(
+    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+    const AotCompilationOptions& options) {
+  std::vector<std::unique_ptr<HloModule>> hlo_modules;
+  for (const
AotComputationInstance& instance : computations) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(instance.computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + // Dump computation proto state if flag is set. + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + if (!directory_path.empty()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr session_module, + computation_tracker_.SnapshotComputation(versioned_handle.handle)); + string filename = tensorflow::strings::StrCat( + "computation_", versioned_handle.handle.handle(), "__", + session_module->entry().name(), "__version_", + versioned_handle.version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + HloModuleConfig hlo_module_config(*program_shape); + hlo_module_config.set_debug_options( + legacy_flags::GetDebugOptionsFromFlags()); + auto* computation_layout = + hlo_module_config.mutable_entry_computation_layout(); + if (flags->xla_hlo_profile) { + hlo_module_config.enable_hlo_profiling(true); + } + for (int i = 0; i < instance.argument_layouts.size(); ++i) { + const Shape& argument_layout = *instance.argument_layouts[i]; + if (ShapeUtil::IsTuple(argument_layout)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_layout)); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *instance.result_layout)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + computation_tracker_.BuildHloModule( + versioned_handle, hlo_module_config, + /*include_unreachable_instructions=*/true)); + hlo_modules.push_back(std::move(hlo_module)); + } + + return compiler_->CompileAheadOfTime(std::move(hlo_modules), + MakeHloDumper(), options); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h new file mode 100644 index 00000000000..3358305c03c --- /dev/null +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// An XLA Service specialization for ahead-of-time compilation. This only
+// instantiates a Compiler object for the relevant platform; it does not
+// instantiate or require an execution backend.
+class CompileOnlyService : public Service {
+ public:
+  // Factory for creating a CompileOnlyService. The parameter platform is the
+  // platform that the service should target. If platform is null then the
+  // default platform is used.
+  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+      perftools::gputools::Platform* platform);
+  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+      const ServiceOptions& options);
+
+  // A description of a computation to compile using CompileAheadOfTime.
+  struct AotComputationInstance {
+    ComputationHandle computation;
+    std::vector<const Shape*> argument_layouts;
+    const Shape* result_layout = nullptr;
+  };
+
+  // Compiles a list of computations for ahead-of-time execution. This is
+  // intended for use in static compilation. See
+  // |CompileOnlyClient::CompileAheadOfTime| for additional details.
+  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+  CompileAheadOfTime(
+      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+      const AotCompilationOptions& options);
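To see how the pieces declared above fit together, here is a hypothetical AOT call sequence. Every name below (`computation_handle`, `arg_shape`, `result_shape`, `service_options`, `aot_options`) is a placeholder for values the caller would already have; none are defined in this change, and the snippet assumes an enclosing function whose return type lets `TF_ASSIGN_OR_RETURN` propagate errors.

```cpp
// Hypothetical ahead-of-time compilation flow against the interface above.
CompileOnlyService::AotComputationInstance instance;
instance.computation = computation_handle;       // placeholder handle
instance.argument_layouts = {&arg_shape};        // placeholder argument shape
instance.result_layout = &result_shape;          // placeholder result shape

TF_ASSIGN_OR_RETURN(std::unique_ptr<CompileOnlyService> service,
                    CompileOnlyService::NewService(service_options));
TF_ASSIGN_OR_RETURN(auto aot_results,
                    service->CompileAheadOfTime({instance}, aot_options));
// aot_results holds one AotCompilationResult per requested computation.
```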
+  // Override Service methods that require or imply the existence of an
+  // execute backend. Note that this does not include TransferToClient, as
+  // computing constants produces global data that we may wish to transfer.
+  tensorflow::Status Execute(const ExecuteRequest* arg,
+                             ExecuteResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
+                                     ExecuteParallelResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status GetDeviceHandles(
+      const GetDeviceHandlesRequest* arg,
+      GetDeviceHandlesResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support devices.");
+  }
+  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+                                  ExecuteAsyncResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status WaitForExecution(
+      const WaitForExecutionRequest* arg,
+      WaitForExecutionResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status TransferToServer(
+      const TransferToServerRequest* arg,
+      TransferToServerResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status TransferToInfeed(
+      const TransferToInfeedRequest* arg,
+      TransferToInfeedResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status TransferFromOutfeed(
+      const TransferFromOutfeedRequest* arg,
+      TransferFromOutfeedResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
+                                 ResetDeviceResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support devices.");
+  }
+
+ private:
+  explicit CompileOnlyService(
+      Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
+  CompileOnlyService(const CompileOnlyService&) = delete;
+  void operator=(const CompileOnlyService&) = delete;
+
+  // The compiler for the target platform. This is included in place of
+  // the Service::execute_backend_'s compiler, since execute_backend_ is a
+  // nullptr in CompileOnlyService.
+  Compiler* compiler_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 85c2d03e1bc..7ae285170e4 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -110,29 +111,24 @@ class Compiler {
   // The compiler may optionally specialize to the individual device
   // (not just type of device) indicated by the executor.
   //
-  // TODO(leary) will need to update this API when a single computation can run
-  // across multiple devices simultaneously.
+  // Use the overload below to compile computations that run in parallel.
   virtual StatusOr<std::unique_ptr<Executable>> Compile(
-      std::unique_ptr<HloModule> module,
-      std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+      std::unique_ptr<HloModule> module, HloDumper dump_hlo,
       perftools::gputools::StreamExecutor* executor) = 0;
 
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> hlo_module,
-      std::vector<std::unique_ptr<HloModuleConfig>> module_config,
-      HloDumper dump_hlo,
+      std::vector<std::unique_ptr<HloModule>> modules, HloDumper dump_hlo,
       std::vector<perftools::gputools::StreamExecutor*> stream_exec) = 0;
 
   // Compiles the HLO module for ahead-of-time execution. This is intended for
   // use in static compilation.
   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      std::vector<std::unique_ptr<HloModule>> module,
-      std::vector<std::unique_ptr<HloModuleConfig>> module_config,
-      HloDumper dump_hlo, const AotCompilationOptions& options) = 0;
+  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+                     HloDumper dump_hlo,
+                     const AotCompilationOptions& options) = 0;
 
   /////
   // The Compiler class also serves as a point to register compiler objects
@@ -153,6 +149,19 @@ class Compiler {
   static StatusOr<Compiler*> GetForPlatform(
       const perftools::gputools::Platform* platform);
 
+  // Returns a function that computes the size in bytes of the logical
+  // buffer that contains a shape.
+  virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;
+
+  // Returns a function that computes the size in bytes of a given
+  // logical buffer.
+  std::function<int64(const LogicalBuffer&)> BufferSizeBytesFunction() {
+    HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
+    return [shape_size](const LogicalBuffer& buffer) {
+      return shape_size(buffer.shape());
+    };
+  }
+
  private:
   // Mutex that guards the platform-compiler map.
   static tensorflow::mutex* platform_compiler_mutex_;
diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc
index 281277bed57..9aa32a1fb76 100644
--- a/tensorflow/compiler/xla/service/computation_tracker.cc
+++ b/tensorflow/compiler/xla/service/computation_tracker.cc
@@ -26,8 +26,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 
+using ::tensorflow::strings::Appendf;
+
 namespace xla {
 
 ComputationTracker::ComputationTracker() : next_computation_(1) {}
@@ -50,12 +53,28 @@ StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
   // For each embedded computation, create a new computation based on its
   // serialized data, and place the mapping from the old computation handle to
   // the new computation handle.
+
+  // Build a mapping from old embedded computation handles to new computation
+  // handles. We build the ID mapping first since the embedded computations are
+  // in no particular order and may refer to each other.
   std::map<int64, ComputationHandle> old_to_new;
   for (const SessionComputation& computation :
        session_module.embedded_computations()) {
     const int64 old_handle = computation.computation_handle().handle();
-    TF_ASSIGN_OR_RETURN(old_to_new[old_handle],
-                        LoadSessionComputation(computation, &old_to_new));
+    if (!old_to_new.emplace(old_handle, AllocateHandle()).second) {
+      return InvalidArgument("Duplicate embedded computation handle %lld",
+                             old_handle);
+    }
+  }
+
+  // Create a new computation from each serialized embedded computation.
+ for (const SessionComputation& computation : + session_module.embedded_computations()) { + const int64 old_handle = computation.computation_handle().handle(); + const ComputationHandle& new_handle = old_to_new[old_handle]; + TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], + UserComputation::MakeWithRemapping( + computation, new_handle, old_to_new)); } // Finally, place the entry computation in the tracker with all of the @@ -130,7 +149,7 @@ void ComputationTracker::ComputeComputationPostOrder( std::set* visited, std::list* post_order) const { if (visited->count(versioned_handle) > 0) { - DCHECK_EQ(1, visited->count(versioned_handle)); + CHECK_EQ(1, visited->count(versioned_handle)); return; } @@ -145,14 +164,19 @@ void ComputationTracker::ComputeComputationPostOrder( visited->insert(versioned_handle); post_order->push_back(versioned_handle); - return; } StatusOr> ComputationTracker::BuildHloModule( const VersionedComputationHandle& entry_handle, - bool include_unused_parameters) const { + const HloModuleConfig& config, + bool include_unreachable_instructions) const { tensorflow::mutex_lock lock(computation_mutex_); + VLOG(1) << "BuildHloModule(" << entry_handle + << ", include_unreachable_instructions=" + << include_unreachable_instructions << ")"; + XLA_VLOG_LINES(1, ToStringInternal()); + TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, ResolveInternal(entry_handle.handle)); @@ -174,9 +198,17 @@ StatusOr> ComputationTracker::BuildHloModule( return hlo_computations.at(versioned_handle); }; + // Print the post-order list for this entry computation. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Visiting UserComputations in post order:"; + for (const VersionedComputationHandle& versioned_handle : post_order) { + VLOG(2) << " " << versioned_handle; + } + } + string module_name = tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle); + auto module = MakeUnique(module_name, entry_handle, config); for (auto versioned_handle : post_order) { UserComputation* computation = ResolveInternal(versioned_handle.handle).ValueOrDie(); @@ -184,7 +216,7 @@ StatusOr> ComputationTracker::BuildHloModule( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_computation, computation->BuildHloComputation(versioned_handle.version, resolver, - include_unused_parameters)); + include_unreachable_instructions)); // Add the newly created computation to VersionedHandle-to-HloComputation // map. @@ -201,4 +233,23 @@ StatusOr> ComputationTracker::BuildHloModule( return std::move(module); } +string ComputationTracker::ToString() const { + tensorflow::mutex_lock lock(computation_mutex_); + return ToStringInternal(); +} + +string ComputationTracker::ToStringInternal() const { + string out; + Appendf(&out, "ComputationTracker(%p):\n", this); + for (const auto& handle_computation : opaque_to_computation_) { + int64 handle = handle_computation.first; + const std::unique_ptr& computation = + handle_computation.second; + Appendf(&out, " %4lld : %s \"%s\"\n", handle, + computation->GetVersionedHandle().ToString().c_str(), + computation->name().c_str()); + } + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h index 7d0660d7f6d..d42d66adefe 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ b/tensorflow/compiler/xla/service/computation_tracker.h @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" @@ -72,13 +73,18 @@ class ComputationTracker { // Builds an HLO module using the specified computation as the entry. The // module will include the entry computation as well as all computations which // are called directly or indirectly from the entry computation via operations - // like "map". If include_unused_parameters is true, then all parameters are - // lowered to HLO instructions even if they are not used. This ensures the - // entry HloComputation has the same program shape (ProgramShape) as the entry - // UserComputation. + // like "map". config is the HLO module configuration to use for the + // constructed module. + // If include_unreachable_instructions is true, then instructions + // which are not reachable from the root are lowered into HloInstructions + // including unreachable parameters. This ensures the entry HloComputation has + // the same program shape (ProgramShape) as the entry UserComputation. StatusOr> BuildHloModule( const VersionedComputationHandle& entry_handle, - bool include_unused_parameters = true) const; + const HloModuleConfig& config, + bool include_unreachable_instructions = true) const; + + string ToString() const; private: // Bumps the next_computation_ number and returns the allocated number wrapped @@ -117,6 +123,8 @@ class ComputationTracker { std::list* post_order) const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); + string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); + // Guards the computation mapping. Marked mutable so that the Resolve method // can remain const; Resolve does't really modify the tracker in any way, but // it has to lock the mutex for safety. diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 81f54c26ec5..a3803c34ba7 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -16,19 +16,20 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include -#include -#include #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -36,6 +37,9 @@ namespace xla { namespace { +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; + // InstructionCopier encapsulates indices at which to copy 'instruction'. // All 'instruction' users in 'copy_users' are updated to use the copy. 
// @@ -52,7 +56,7 @@ namespace { // // Example two-element tuple with one element that needs a copy: // -// Tuple // instruction +// original-instruction // / \ // GTE(0) GTE(1) // | | @@ -60,23 +64,54 @@ namespace { // \ / // Tuple // copied-instruction // +// As an optimization, if the original instruction is itself a Tuple +// instruction, we elide the unnecessary extra GTE and Tuple instructions, +// and just insert the copy into a new Tuple instruction, with control +// dependencies to ensure the copy occurs after any possible interference. class InstructionCopier { public: - InstructionCopier(const bool init_value, HloInstruction* instruction, - const std::vector& copy_users); + InstructionCopier(HloInstruction* instruction, + const std::vector& copy_users) + : instruction_(instruction), + copy_users_(copy_users), + indices_to_copy_(instruction->shape()), + control_predecessors_(instruction->shape()) {} + + // Sets indices that are read-only, and thus do not need to be copied. + void SetReadOnlyIndices(const ShapeTree& read_only_indices) { + read_only_indices_ = read_only_indices; + } + + // Sets copy overrides, which are copy instructions to use at each index. This + // is used to share a single copy of read-only entry parameters and constants + // between multiple While loops. + void SetCopyOverrides(const ShapeTree& copy_overrides) { + copy_overrides_ = copy_overrides; + } // Returns true if all recorded indices are false (returns true otherwise). bool HasAllIndicesFalse() const; // Records instruction buffer indices which point-to a Parameter or Constant. - tensorflow::Status RecordIndicesWhichPointToParamOrConstant( + Status RecordIndicesWhichPointToParamOrConstant( const TuplePointsToAnalysis& points_to_analysis); // Records instruction buffer indices to copy which are necessary to ensure: // *) PointsToSet of 'instruction_' is unambiguous and distinct. // *) No liveness interference between 'instruction_' and 'other_instruction'. - tensorflow::Status RecordIndicesToCopyForColocatingBuffers( - BufferLiveness* liveness, HloInstruction* other_instruction); + // + // If 'read_only_indices_out' is non-null, read-only indices are set to true. + Status RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out); + + // Records control predecessors to add for inserted copy instructions. + // 'parameter' must have the same shape as the instruction that will be + // copied, and must define all buffers in the shape. Control predecessors are + // only recorded for indices that have already been marked for copying. + Status RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter); // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', // and replaces all uses for instructions in 'copy_users_' with copy. @@ -88,15 +123,29 @@ class InstructionCopier { const std::vector& copy_users() const { return copy_users_; } private: + // Does the given index represent a read-only buffer? + bool IsReadOnlyIndex(const ShapeIndex& index) const { + return !ShapeUtil::IsNil(read_only_indices_.shape()) && + read_only_indices_.element(index); + } + + // Returns the copy override at the given index, or nullptr. + HloInstruction* GetCopyOverride(const ShapeIndex& index) const { + return ShapeUtil::IsNil(copy_overrides_.shape()) + ? 
nullptr + : copy_overrides_.element(index); + } + // Records instruction buffer indices which have ambiguous or non-distinct // points-to sets. - tensorflow::Status RecordAmbiguousOrNonDistinctIndices( + Status RecordAmbiguousOrNonDistinctIndices( const TuplePointsToAnalysis& points_to_analysis); - // Records instruction buffer indices which have interferring live ranges + // Records instruction buffer indices which have interfering live ranges // with 'other_instruction' buffers at same index. - tensorflow::Status RecordIndicesWhichInterfereWithOtherInstruction( - BufferLiveness* liveness, HloInstruction* other_instruction); + Status RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out); // Recursively inserts copies of 'instruction' tuple elements at indices // specified in 'indices_to_copy', and returns the copy of 'instruction'. @@ -107,28 +156,25 @@ class InstructionCopier { } HloInstruction* instruction_; - std::vector copy_users_; + const std::vector copy_users_; ShapeTree indices_to_copy_; + ShapeTree> control_predecessors_; + ShapeTree read_only_indices_; + ShapeTree copy_overrides_; }; -InstructionCopier::InstructionCopier( - const bool init_value, HloInstruction* instruction, - const std::vector& copy_users) - : instruction_(instruction), - copy_users_(copy_users), - indices_to_copy_(instruction->shape(), init_value) {} - bool InstructionCopier::HasAllIndicesFalse() const { bool all_indices_false = true; - TF_CHECK_OK(indices_to_copy_.ForEachElement([&all_indices_false]( - const ShapeIndex& /*index*/, bool /*is_leaf*/, const bool& data) { - if (data) all_indices_false = false; - return tensorflow::Status::OK(); - })); + indices_to_copy_.ForEachElement( + [&all_indices_false](const ShapeIndex& /*index*/, bool data) { + if (data) { + all_indices_false = false; + } + }); return all_indices_false; } -tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( +Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( const TuplePointsToAnalysis& points_to_analysis) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction_); @@ -141,72 +187,73 @@ tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( // Multiple buffers within a parameter/constant may be live out, so collect // a set of indices at which to copy first. - TF_RETURN_IF_ERROR(points_to.ForEachElement([this]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { - for (auto buffer : buffers) { - // pointee is the HloInstruction producing the buffer which may be - // liveout. - HloInstruction* pointee = buffer->instruction(); - if (pointee->opcode() == HloOpcode::kParameter || - pointee->opcode() == HloOpcode::kConstant) { - VLOG(2) << "Parameter or constant buffer " << buffer->ToString() - << " index: " << tensorflow::str_util::Join(index, ",") - << " may be live out of computation: " << pointee->ToString(); - RecordIndex(index); - } - } - return tensorflow::Status::OK(); - })); - return tensorflow::Status::OK(); + points_to.ForEachElement( + [this](const ShapeIndex& index, + const std::vector& buffers) { + if (IsReadOnlyIndex(index)) { + return; + } + for (const LogicalBuffer* buffer : buffers) { + // pointee is the HloInstruction producing the buffer which may be + // liveout. 
+ HloInstruction* pointee = buffer->instruction(); + if (pointee->opcode() == HloOpcode::kParameter || + pointee->opcode() == HloOpcode::kConstant) { + VLOG(2) << "Parameter or constant buffer " << buffer->ToString() + << " index: " << tensorflow::str_util::Join(index, ",") + << " may be live out of computation: " + << pointee->ToString(); + RecordIndex(index); + break; + } + } + }); + return Status::OK(); } -tensorflow::Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( - BufferLiveness* liveness, HloInstruction* other_instruction) { +Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out) { TF_RETURN_IF_ERROR( - RecordAmbiguousOrNonDistinctIndices(liveness->points_to_analysis())); + RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( - liveness, other_instruction)); - return tensorflow::Status::OK(); + liveness, other_instruction, read_only_indices_out)); + return Status::OK(); } -tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( +Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( const TuplePointsToAnalysis& points_to_analysis) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction_); // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - std::unordered_map> + FlatMap> buffer_to_source_indices; - TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); - } - } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (auto& buffer : buffers) { - auto it = buffer_to_source_indices.find(buffer); - if (it == buffer_to_source_indices.end()) { - buffer_to_source_indices.insert({buffer, std::vector()}); - } - buffer_to_source_indices[buffer].push_back(index); - } - return tensorflow::Status::OK(); - })); + points_to.ForEachElement( + [this, &buffer_to_source_indices]( + const ShapeIndex& index, + const std::vector& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. + if (!indices_to_copy_.element(index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " with ambiguous points-to set."; + RecordIndex(index); + } + } + // For each 'buffer': record a mapping from 'buffer' to 'index'. + for (const LogicalBuffer* buffer : buffers) { + buffer_to_source_indices[buffer].push_back(index); + } + }); // Record all non-distinct indices detected in 'buffer_to_source_indices'. - for (auto& buff_to_src : buffer_to_source_indices) { + for (const auto& buff_to_src : buffer_to_source_indices) { if (buff_to_src.second.size() == 1) { continue; } - for (auto& src_index : buff_to_src.second) { + for (const ShapeIndex& src_index : buff_to_src.second) { // Record non-distinct points-to set at 'src_index'. 
if (!indices_to_copy_.element(src_index)) { VLOG(2) << "Adding copy of buffer for instruction: " @@ -217,23 +264,26 @@ tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( } } } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status -InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( - BufferLiveness* liveness, HloInstruction* other_instruction) { +Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out) { // Record all buffer indices for 'instruction_', which interfere with // 'other_instruction' at the same index. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( instruction_->shape(), - [this, &liveness, &other_instruction](const Shape& /*subshape*/, - const ShapeIndex& index) { + [this, &liveness, other_instruction, read_only_indices_out]( + const Shape& /*subshape*/, const ShapeIndex& index) { + if (IsReadOnlyIndex(index)) { + return; + } if (indices_to_copy_.element(index)) { // Return if previous pass already set index. - return tensorflow::Status::OK(); + return; } - auto& points_to_analysis = liveness->points_to_analysis(); + const auto& points_to_analysis = liveness.points_to_analysis(); // Lookup buffers for 'instruction_' and 'other_instruction'. const std::vector instruction_buffers = points_to_analysis.GetPointsToSet(instruction_).element(index); @@ -252,20 +302,24 @@ InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( // then that buffer is not updated on the path between the two // instructions. Therefore, any other (possibly interference-causing) // users of that buffer from 'other_instruction' will see the same data, - // irrespecive of whether we insert a copy of this buffer at + // irrespective of whether we insert a copy of this buffer at // 'instruction_' or not. if (other_instruction_buffers.size() == 1 && other_instruction_buffers[0]->id() == instruction_buffer->id()) { - return tensorflow::Status::OK(); + if (read_only_indices_out != nullptr) { + *read_only_indices_out->mutable_element(index) = true; + } + return; } - // We cant say anything about the ambiguity of 'other_instruction' at + // We can't say anything about the ambiguity of 'other_instruction' at // this point, so we need to check interference between the single // buffer in the points-to set of 'instruction_' and all buffers in // 'other_instruction_buffers'. - for (auto& other_buffer : other_instruction_buffers) { - if (liveness->MayInterfere(*instruction_buffer, *other_buffer)) { + for (const LogicalBuffer* other_buffer : other_instruction_buffers) { + if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { VLOG(2) << "Adding copy of buffer for instruction: " << instruction_->name() + << " instruction_buffer: " << instruction_buffer->ToString() << " at index: " << tensorflow::str_util::Join(index, ",") << " because of interference with buffer: " << other_buffer->ToString(); @@ -273,40 +327,88 @@ InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( break; } } - return tensorflow::Status::OK(); - })); - return tensorflow::Status::OK(); + }); + return Status::OK(); +} + +// This is called when 'instruction_' is a while body root, and 'parameter' is +// the while body parameter. We record all users of all aliases of 'parameter' +// as control predecessors, so that when we add a copy of 'instruction_', we can +// mark the control dependencies. 
This is necessary because points-to and +// liveness analysis doesn't know about the aliasing between the while body root +// and param. Without these control dependencies, the copy might get scheduled +// to run at a point that interferes with users of the buffer. +Status InstructionCopier::RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter) { + return indices_to_copy_.ForEachElementWithStatus( + [this, &points_to_analysis, parameter](const ShapeIndex& index, + bool will_copy) { + if (will_copy) { + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + points_to_analysis.GetBufferDefinedAt(parameter, index)); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + user, points_to_analysis)) { + continue; + } + + if (user != instruction_) { + control_predecessors_.mutable_element(index)->push_back(user); + } + } + } + } + return Status::OK(); + }); } // Recursively inserts copies of 'instruction' tuple element buffers at // indices in 'indices_to_copy_', expanding tuples as needed. -// TODO(b/31159897) Remove superfluous Tuple->GTE->Tuple expressions. HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, ShapeIndex* index) { - std::vector element_copies; const int64 num_tuple_elements = ShapeUtil::TupleElementCount(instruction->shape()); + std::vector elem_copies(num_tuple_elements); for (int64 i = 0; i < num_tuple_elements; ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i)); - HloInstruction* element_copy; - index->push_back(i); - if (ShapeUtil::IsTuple(gte->shape())) { - element_copy = CopyTuple(gte, index); + HloInstruction* elem; + if (instruction->opcode() == HloOpcode::kTuple) { + // If the instruction is already a Tuple instruction, we know that the + // element buffers are aliased, so we can just grab the operand directly. + elem = instruction->mutable_operand(i); } else { - if (indices_to_copy_.element(*index)) { - element_copy = gte->parent()->AddInstruction( - HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte)); - } else { - element_copy = gte; + // Otherwise we need to add a GTE to unpack the element out of the tuple. 
+ elem = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + } + index->push_back(i); + if (ShapeUtil::IsTuple(elem->shape())) { + elem_copies[i] = CopyTuple(elem, index); + } else if (!indices_to_copy_.element(*index)) { + elem_copies[i] = elem; + } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { + elem_copies[i] = copy_override; + } else { + HloInstruction* elem_copy = elem->parent()->AddInstruction( + HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); + for (HloInstruction* control_predecessor : + control_predecessors_.element(*index)) { + VLOG(2) << "Adding control dependency from " + << control_predecessor->ToString() << " to " + << elem_copy->ToString(); + TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); } + elem_copies[i] = elem_copy; } index->pop_back(); - element_copies.push_back(element_copy); } return instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); + HloInstruction::CreateTuple(elem_copies)); } // Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. @@ -327,8 +429,88 @@ HloInstruction* InstructionCopier::Copy() { return copy; } +// The 'read_only_indices' are initialized based on points-to analysis on the +// while body corresponding to 'while_hlo'. If the init buffer corresponding to +// a read-only index aliases with a constant, it cannot be considered read-only, +// and must be copied. This is necessary because BufferAssignment does not +// currently assign an allocation for constants (b/32248867). +// This function performs this fix-up of 'read_only_indices'. +// +// Returns a ShapeTree of copy_overrides, which implements an optimization to +// allow multiple while loops that share the same read-only constants to +// share a single copy. +StatusOr> RevertReadOnlyIndicesForConstants( + const HloInstruction* while_hlo, + const TuplePointsToAnalysis& points_to_analysis, + ShapeTree* read_only_indices, + FlatMap* shared_copies) { + const HloInstruction* init_hlo = while_hlo->operand(0); + const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + FlatSet buffer_set; + + ShapeTree copy_overrides(init_hlo->shape()); + points_to.ForEachElement( + [init_hlo, read_only_indices, shared_copies, &buffer_set, + ©_overrides](const ShapeIndex& index, + const std::vector& buffers) { + // Look for read-only entry parameters. + if (!read_only_indices->element(index)) { + return; + } + for (const LogicalBuffer* buffer : buffers) { + HloInstruction* pointee = buffer->instruction(); + const bool is_constant = pointee->opcode() == HloOpcode::kConstant; + if (!is_constant) { + continue; + } + + // We have found an constant that is read-only in + // the while body. These buffers are managed by the caller, and cannot + // be aliased with HLO buffers. Revert this read-only index, + // to allow it to be copied. + *read_only_indices->mutable_element(index) = false; + + // Optimization to allow multiple while loops that share the same + // read-only entry constants to share a single copy. + // Only unambiguous and distinct array-shaped buffers are allowed, to + // reduce code complexity. 
The shape of the entry parameter must be + // identical to the shape of the init_hlo at this index, to ensure + // there were no intervening bitcast or GTE instructions, which are + // also hard to handle. + const Shape& pointee_shape = pointee->shape(); + const Shape& init_shape = + ShapeUtil::GetSubshape(init_hlo->shape(), index); + if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && + ShapeUtil::Equal(pointee_shape, init_shape) && + buffer_set.count(buffer) < 1) { + HloInstruction** copy = &(*shared_copies)[pointee]; + if (*copy == nullptr) { + *copy = + pointee->parent()->AddInstruction(HloInstruction::CreateUnary( + pointee_shape, HloOpcode::kCopy, pointee)); + } + // Add the copy as an override. + *copy_overrides.mutable_element(index) = *copy; + } + + // Tracks whether this current buffer is distinct. + buffer_set.insert(buffer); + + // We've already reverted the read-only index and handled the + // single-copy optimization above, so there's nothing more to do. + break; + } + }); + return copy_overrides; +} + } // anonymous namespace +// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the +// base class, since the regular CopyInsertion logic above selectively copies +// tuple elements, while this method assumes all buffers need to be deep copied. StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { auto copy_it = inserted_copies_.find(hlo); if (copy_it == inserted_copies_.end()) { @@ -347,85 +529,96 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( std::unique_ptr liveness, BufferLiveness::Run(module, MakeUnique(module))); - auto& points_to_analysis = liveness->points_to_analysis(); + const auto& points_to_analysis = liveness->points_to_analysis(); XLA_VLOG_LINES(2, points_to_analysis.ToString()); XLA_VLOG_LINES(2, module->ToString()); - // Gather references to all while body computations in 'module'. - std::unordered_set while_body_computations; - // Gather references to all while instructions in 'module' by computation. - std::unordered_map> - while_instructions; + // Gather all while body computations and while instructions. + FlatSet while_body_computations; + std::vector while_instructions; for (auto& computation : module->computations()) { for (auto& instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile) { - continue; + if (instruction->opcode() == HloOpcode::kWhile) { + while_body_computations.insert(instruction->while_body()); + while_instructions.push_back(instruction.get()); } - while_body_computations.insert(instruction->while_body()); - auto it = while_instructions.find(computation.get()); - if (it == while_instructions.end()) { - while_instructions.insert( - {computation.get(), std::vector()}); - } - while_instructions[computation.get()].emplace_back(instruction.get()); } } + // Collect instruction buffer indices to copy in 'instructions_to_copy'. + std::vector instructions_to_copy; + + // Add copies of computation root instructions, if needed. + FlatMap> while_body_read_only_indices; for (auto& computation : module->computations()) { VLOG(2) << "computation " << computation->name(); - - // Collect instruction buffer indices to copy in 'instructions_to_copy'. - std::vector instructions_to_copy; - - // Add copies of while 'init' operand instructions (if needed). - // TODO(b/33301720) Remove redundant while instruction copies. 
- auto it = while_instructions.find(computation.get()); - if (it != while_instructions.end()) { - for (auto& while_hlo : it->second) { - // Create InstructionCopier for init operand of while instruction. - HloInstruction* init_hlo = while_hlo->mutable_operand(0); - instructions_to_copy.push_back( - InstructionCopier(/*init_value=*/false, init_hlo, {while_hlo})); - InstructionCopier& init_copier = instructions_to_copy.back(); - // Record 'init' buffer indices which point-to a Constant or Parameter. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( - liveness->points_to_analysis())); - // Record indices necessary to colocate while and init operand buffers. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( - liveness.get(), while_hlo)); - } - } - - // Create InstructionCopier for computation root instruction. - instructions_to_copy.push_back(InstructionCopier( - /*init_value=*/false, computation->root_instruction(), {})); - InstructionCopier& root_copier = instructions_to_copy.back(); - + InstructionCopier root_copier(computation->root_instruction(), + /*copy_users=*/{}); if (while_body_computations.count(computation.get()) > 0) { - // Record root indices to copy for while body sub-computations. - // We do not need to call RecordIndicesWhichPointToParamOrConstant for - // the while root instruction here, because any neccessary copies needed - // to avoid constant or parameters in the output are handled by while.init - // operand copy insertion above (which will share an allocation). + // Record root indices to copy for while body sub-computations. We do not + // need to call RecordIndicesWhichPointToParamOrConstant for the while + // body root instruction here, because any necessary copies needed to + // avoid constants or parameters in the output are handled by while.init + // operand copy insertion below (which will share an allocation). + HloInstruction* while_body_param = computation->parameter_instruction(0); + ShapeTree read_only_indices(while_body_param->shape()); TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( - liveness.get(), computation->parameter_instruction(0))); + *liveness, while_body_param, &read_only_indices)); + while_body_read_only_indices[computation.get()] = read_only_indices; + + // Mark control predecessors, based on the body param, for any copies + // we'll be inserting. This ensures the copy doesn't run too early. + TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( + points_to_analysis, while_body_param)); } else { // Record root indices to copy for general computations. TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( - liveness->points_to_analysis())); + points_to_analysis)); } + instructions_to_copy.push_back(root_copier); + } - for (auto& to_copy : instructions_to_copy) { - if (to_copy.HasAllIndicesFalse()) { - continue; - } - changed = true; + // Add copies of while 'init' operand instructions, if needed. 'shared_copies' + // is used to ensure that multiple while loops can share a single copy of the + // same entry parameter or constant, if all loops use it read-only. + // + // TODO(b/33301720) Remove redundant while instruction copies. + FlatMap shared_copies; + for (HloInstruction* while_hlo : while_instructions) { + // Fix read_only_indices to account for entry constants. Also + // initialize copy_overrides, which ensures a single copy for each read-only + // constant that is used in multiple while loops. 
+    ShapeTree<bool>* read_only_indices =
+        &while_body_read_only_indices[while_hlo->while_body()];
+    TF_ASSIGN_OR_RETURN(
+        const ShapeTree<HloInstruction*> copy_overrides,
+        RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis,
+                                          read_only_indices, &shared_copies));
+    // Create InstructionCopier for init operand of while instruction.
+    HloInstruction* init_hlo = while_hlo->mutable_operand(0);
+    InstructionCopier init_copier(init_hlo, {while_hlo});
+    init_copier.SetReadOnlyIndices(*read_only_indices);
+    init_copier.SetCopyOverrides(copy_overrides);
+    // Record 'init' buffer indices which point-to a Constant or Parameter.
+    TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant(
+        points_to_analysis));
+    // Record indices necessary to colocate while and init operand buffers.
+    TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers(
+        *liveness, while_hlo, /*read_only_indices_out=*/nullptr));
+    instructions_to_copy.push_back(init_copier);
+  }
 
-      // Copy instruction at recorded buffer indices.
-      HloInstruction* copy = to_copy.Copy();
-      if (to_copy.instruction() == computation->root_instruction()) {
-        computation->set_root_instruction(copy);
-      }
+  for (InstructionCopier& to_copy : instructions_to_copy) {
+    if (to_copy.HasAllIndicesFalse()) {
+      continue;
+    }
+    changed = true;
+
+    // Copy instruction at recorded buffer indices.
+    HloComputation* computation = to_copy.instruction()->parent();
+    HloInstruction* copy = to_copy.Copy();
+    if (to_copy.instruction() == computation->root_instruction()) {
+      computation->set_root_instruction(copy);
     }
   }
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 5bf6f2501b1..28bb62e40c7 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 
@@ -32,7 +33,6 @@ namespace xla {
 // different lifetimes than computation results.
 class CopyInsertion : public HloPassInterface {
  public:
-  ~CopyInsertion() override {}
   tensorflow::StringPiece name() const override { return "copy-insertion"; }
 
   // Run the pass on the given module. Returns whether the module was changed
@@ -46,7 +46,7 @@ class CopyInsertion : public HloPassInterface {
 
   // A map containing all copies inserted during the copy insertion pass. The
   // key is the copied instruction and the value is the copy.
-  std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;
+  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index e64da58dc79..cc77339bb63 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -20,18 +20,23 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xla/test_helpers.h" +namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +using ::testing::UnorderedElementsAre; + class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { @@ -39,59 +44,27 @@ class CopyInsertionTest : public HloTestBase { EXPECT_IS_OK(copy_insertion.Run(module).status()); // Verify the points to set of the root of the computation after copy - // insertion contains no constants or parameters. + // insertion contains no constants or parameters, and is distinct and + // non-ambiguous. auto points_to_analysis = TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); - const std::set maybe_live_out_buffers = + const auto& points_to = points_to_analysis->GetPointsToSet( + module->entry_computation()->root_instruction()); + EXPECT_TRUE(points_to.IsDistinct()); + EXPECT_TRUE(!points_to.IsAmbiguous()); + + tensorflow::gtl::FlatSet maybe_live_out_buffers = points_to_analysis ->GetPointsToSet(module->entry_computation()->root_instruction()) .CreateFlattenedSet(); + for (const LogicalBuffer* buffer : maybe_live_out_buffers) { EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); } } - - // OperandTree is a test helper class that simplifies the expression of - // an expected tree of operands (starting at some root instruction) in a - // unit test. - // Each HLO instruction is represented as a node in the OperandTree. - struct OperandTree { - // The expected opcode for this OperandTree node. - HloOpcode opcode; - // The set of operands expected for this OperandTree node. - std::vector operands; - // If non-null, a pointer to the expected HloInstruction at this node. - const HloInstruction* instruction = nullptr; - - // Returns a mutable reference to operand 'i' of this node. - OperandTree& op(int i) { - if (i >= operands.size()) { - operands.resize(i + 1); - } - return operands[i]; - } - - // Check that 'instruction' and its operands match expected values recorded - // in OperandTree. - void Check(const HloInstruction* instruction) { - EXPECT_EQ(opcode, instruction->opcode()); - if (instruction != nullptr) { - EXPECT_EQ(instruction, instruction); - } - if (operands.empty()) { - return; - } - EXPECT_EQ(operands.size(), instruction->operand_count()); - for (int i = 0; i < instruction->operand_count(); ++i) { - operands[i].Check(instruction->operand(i)); - } - } - }; }; -#define EXPECT_INST(A, E...) 
EXPECT_EQ(A, (std::set{E})) - TEST_F(CopyInsertionTest, SingleParameter) { auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( @@ -99,25 +72,16 @@ TEST_F(CopyInsertionTest, SingleParameter) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({x})); - EXPECT_INST(x->users(), tuple); + EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, SingleConstant) { @@ -127,25 +91,16 @@ TEST_F(CopyInsertionTest, SingleConstant) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); - EXPECT_INST(constant->users(), tuple); + EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -169,35 +124,15 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // "constant2" and parameter "x" are pointed to by the tuple and should be - // copied. - - // Check all paths from 'new_root' to 'old_root'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.op(2).opcode = HloOpcode::kGetTupleElement; - op_tree.op(2).op(0).opcode = HloOpcode::kTuple; - op_tree.op(2).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)), + op::Copy(old_root->operand(1)), old_root->operand(2))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -221,32 +156,19 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - EXPECT_INST(constant1->users(), tuple1); - EXPECT_INST(constant2->users(), tuple1, tuple2); - EXPECT_INST(constant3->users(), tuple2); + EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); + EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); + EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kSelect; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kSelect; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(old_root)), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, BitcastParameter) { @@ -258,22 +180,16 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - EXPECT_INST(x->users(), bitcast); + EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check path from 'new_root' to 'old_root'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kCopy; - op_tree.op(0).opcode = HloOpcode::kBitcast; - op_tree.op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Copy(old_root)); } TEST_F(CopyInsertionTest, BitcastConstant) { @@ -286,22 +202,16 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - EXPECT_INST(constant->users(), bitcast); + EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kCopy; - op_tree.op(0).opcode = HloOpcode::kBitcast; - op_tree.op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Copy(old_root)); } TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { @@ -313,25 +223,16 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); - EXPECT_EQ(1, x->user_count()); - EXPECT_EQ(*x->users().begin(), bitcast); + EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); - // Check path from 'new_root' to 'old_root'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, NestedTupleParameter) { @@ -342,47 +243,31 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { // Param shape is: ((F32[], S32[1,2,3]), F32[42]) builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {1, 2, 3})}), - ShapeUtil::MakeShape(F32, {42})}), + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {1, 2, 3})}), + ShapeUtil::MakeShape(F32, {42})}), "param0")); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, - module.entry_computation()->root_instruction()->opcode()); + module->entry_computation()->root_instruction()->opcode()); - HloInstruction* old_root = module.entry_computation()->root_instruction(); - InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); + HloInstruction* old_root = module->entry_computation()->root_instruction(); + InsertCopies(module.get()); + HloInstruction* new_root = module->entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(0).op(0).op(0).op(0).op(0).instruction = old_root; - - op_tree.op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(1).opcode = HloOpcode::kCopy; - op_tree.op(0).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(1).op(0).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(0).op(1).op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT( + new_root, + op::Tuple( + op::Tuple( + op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))), + op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { @@ -392,10 +277,11 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { // Param shape is: ((F32[], S32[1,2,3]), F32[42]) auto param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {1, 2, 3})}), - ShapeUtil::MakeShape(F32, {42})}), + 0, + ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}),
+                                 ShapeUtil::MakeShape(S32, {1, 2, 3})}),
+           ShapeUtil::MakeShape(F32, {42})}),
       "param0"));
 
   // The return value of the computation is the zero-th element of the nested
@@ -403,30 +289,17 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
   auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
       ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));
 
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+  EXPECT_EQ(gte, module->entry_computation()->root_instruction());
 
-  HloInstruction* old_root = module.entry_computation()->root_instruction();
-  InsertCopies(&module);
-  HloInstruction* new_root = module.entry_computation()->root_instruction();
+  HloInstruction* old_root = module->entry_computation()->root_instruction();
+  InsertCopies(module.get());
 
-  // Check all paths from 'new_root' to 'old_root'.
-  OperandTree op_tree;
-  op_tree.opcode = HloOpcode::kTuple;
-
-  op_tree.op(0).opcode = HloOpcode::kCopy;
-  op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(0).op(0).op(0).instruction = old_root;
-
-  op_tree.op(1).opcode = HloOpcode::kCopy;
-  op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(1).op(0).op(0).instruction = old_root;
-
-  op_tree.Check(new_root);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::Copy(op::GetTupleElement(old_root)),
+                        op::Copy(op::GetTupleElement(old_root))));
 }
 
 TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
@@ -452,27 +325,21 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
   builder.AddInstruction(HloInstruction::CreateGetTupleElement(
       ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
 
-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+  EXPECT_EQ(gte, module->entry_computation()->root_instruction());
 
-  HloInstruction* old_root = module.entry_computation()->root_instruction();
-  InsertCopies(&module);
-  HloInstruction* new_root = module.entry_computation()->root_instruction();
+  HloInstruction* old_root = module->entry_computation()->root_instruction();
+  InsertCopies(module.get());
 
-  // Check path from 'new_root' to 'old_root'.
-  OperandTree op_tree;
-  op_tree.opcode = HloOpcode::kCopy;
-  op_tree.op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(0).instruction = old_root;
-
-  op_tree.Check(new_root);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(old_root));
 }
 
 class WhileCopyInsertionTest : public CopyInsertionTest {
  protected:
-  WhileCopyInsertionTest() : module_(TestName()) {}
+  WhileCopyInsertionTest() : module_(CreateNewModule()) {}
 
   // Builds a While condition computation which reads the induction variable
   // from the tuple parameter, and returns a predicate indicating whether this
@@ -530,8 +397,48 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
     return builder.Build();
   }
 
-  // Builds a While body computation with read-only tuple element 0.
+  // Builds a While body computation with output tuple elements dependent on
+  // both input tuple elements.
+ // + // EX: Body({in0, in1, in2}) + // out0 = Add(in0, 1) + // out1 = in1 + // out2 = in2 + // Tuple(out0, out1, out2) + std::unique_ptr BuildDependentBodyComputation2() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + + // Update the induction variable GTE(0). + auto induction_variable = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + induction_variable_shape_, loop_state, 0)); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + // add0 = Add(in0, 1) + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); + // data1 = GTE(1). + HloInstruction* data1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + + // data2 = GTE(2). + HloInstruction* data2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); + + // Create output Tuple. + builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); + + return builder.Build(); + } + + // Builds a While body computation with read-only tuple element 0. // EX: // Body({in0, in1}) // out0 = in0 @@ -549,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Update data GTE(1). auto data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); @@ -566,11 +474,15 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // out0 = Add(in0, 1) // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) // Tuple(out0, out1) - std::unique_ptr BuildIndependentBodyComputation() { + std::unique_ptr BuildIndependentBodyComputation( + bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. + const Shape& loop_state_shape = + nested ? nested_loop_state_shape_ : loop_state_shape_; + auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); // Update the induction variable GTE(0). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -581,16 +493,30 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). 
- auto data = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + HloInstruction* data = nullptr; + if (nested) { + data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + nested_tuple_shape_, loop_state, 1)); + data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, data, 0)); + } else { + data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + } auto update = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1( {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); - // add0 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) + // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + if (nested) { + auto nested_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add1})); + builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple})); + } else { + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + } return builder.Build(); } @@ -643,8 +569,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Builds a While instruction using 'condition' and 'body' sub-computations. // Init operand is initialized to zeros of appropriate shape. - void BuildWhileInstruction(HloComputation* condition, HloComputation* body, - bool nested = false) { + HloInstruction* BuildWhileInstruction(HloComputation* condition, + HloComputation* body, + bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); @@ -658,17 +585,18 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction::CreateTuple({data_init, data_init})); auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); - builder.AddInstruction(HloInstruction::CreateWhile( + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition, body, loop_state_init)); - module_.AddEntryComputation(builder.Build()); - return; + module_->AddEntryComputation(builder.Build()); + return while_hlo; } auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, data_init})); - builder.AddInstruction(HloInstruction::CreateWhile( + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition, body, loop_state_init)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); + return while_hlo; } HloInstruction* BuildWhileInstruction_InitPointsToConstant() { @@ -746,21 +674,23 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstructionWithCustomInit( const Shape& loop_state_shape, HloInstruction* data_init, HloComputation::Builder* builder) { + const bool nested = + ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation()); - auto body = - module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); + module_->AddEmbeddedComputation(BuildConditionComputation(nested)); + auto body = module_->AddEmbeddedComputation( + 
BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( HloInstruction::CreateTuple({induction_var_init, data_init})); auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition, body, loop_state_init)); - module_.AddEntryComputation(builder->Build()); + module_->AddEntryComputation(builder->Build()); return while_hlo; } - HloModule module_; + std::unique_ptr module_; Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {}); Shape data_shape_ = ShapeUtil::MakeShape(F32, {8}); Shape loop_state_shape_ = @@ -782,16 +712,23 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // CopyInsertion pass should not generate any copies. // TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { - auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); - auto body = module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); - BuildWhileInstruction(condition, body); + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto body = + module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); - InsertCopies(&module_); + InsertCopies(module_.get()); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); // No copies should be inserted so root should not be updated. - CHECK_EQ(old_root, new_root); + EXPECT_EQ(old_root, new_root); + + // Both init indices need copies. + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with dependent tuple elements: @@ -801,39 +738,25 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // out1 = Add(BCast(in0), in1) // Tuple(out0, out1) // -// CopyInsertion pass should generate: +// CopyInsertion pass should convert the root instruction to: // -// Tuple // old root -// / \ -// GTE(0) GTE(1) -// | | -// Copy | -// \ / -// Tuple // new root +// Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { - auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); - auto body = module_.AddEmbeddedComputation(BuildDependentBodyComputation()); - BuildWhileInstruction(condition, body); + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); - InsertCopies(&module_); + InsertCopies(module_.get()); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - // Check all paths from 'new_root' to 'old_root'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(new_root, + op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with read-only tuple element 0: @@ -849,20 +772,113 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // \ / // TUPLE (root) // -// CopyInsertion pass should not generate any copies. -// +// CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { - auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); - auto body = module_.AddEmbeddedComputation( + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto body = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - BuildWhileInstruction(condition, body); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); - InsertCopies(&module_); + InsertCopies(module_.get()); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - // No copies should be inserted so root should not be updated. - CHECK_EQ(old_root, new_root); + // No copies should be inserted in the body, so root should not be updated. + EXPECT_EQ(old_root, new_root); + + // Both indices need copies, even though Index 0 is read-only, since both are + // constants, which must be copied. + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); +} + +// Same as above, but with two while loops, sharing entry parameters. +TEST_F(WhileCopyInsertionTest, + DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { + auto condition1 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto body1 = module_->AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + auto body2 = module_->AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_param, data_param})); + + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition2, body2, loop_init)); + module_->AddEntryComputation(builder.Build()); + + InsertCopies(module_.get()); + + // Both while loops alias iter_param, since index 0 is read-only in the body. 
+ EXPECT_EQ(while_hlo1->operand(0)->operand(0), + while_hlo2->operand(0)->operand(0)); + EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param); + + // Each while loop gets its own copy of data_param, since index 1 is not + // read-only in the body. + EXPECT_NE(while_hlo1->operand(0)->operand(1), + while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); +} + +// Same as above, but with two while loops, sharing non-parameters. +TEST_F(WhileCopyInsertionTest, + DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { + auto condition1 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto body1 = module_->AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + auto body2 = module_->AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + // Add dummy ops to ensure loop_init elements aren't entry parameters. + auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary( + iter_param->shape(), HloOpcode::kExp, iter_param)); + auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( + data_param->shape(), HloOpcode::kExp, data_param)); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_value, data_value})); + + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition2, body2, loop_init)); + module_->AddEntryComputation(builder.Build()); + + InsertCopies(module_.get()); + + // No copies of iter_value are necessary, since index 0 is read-only in both + // while bodies. + EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); + EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); + + // Each while loop gets its own copy of data_value, since index 1 is not + // read-only in the body. 
+ EXPECT_NE(while_hlo1->operand(0)->operand(1), + while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); } // Tests while body computation with nested tuple elements: @@ -875,7 +891,8 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { // Add Reverse // | | // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with the +// actual GTE and Tuple instructions optimized away: // // Tuple // old root // / \ @@ -895,110 +912,47 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { // TEST_F(WhileCopyInsertionTest, NestedTupleElements) { auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation(true)); - auto body = module_.AddEmbeddedComputation(BuildNestedBodyComputation()); + module_->AddEmbeddedComputation(BuildConditionComputation(true)); + auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); BuildWhileInstruction(condition, body, true); HloInstruction* old_root = body->root_instruction(); - InsertCopies(&module_); - HloInstruction* new_root = body->root_instruction(); + InsertCopies(module_.get()); - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(body->root_instruction(), + op::Tuple(old_root->operand(0), + op::Tuple(old_root->operand(1)->operand(0), + op::Copy(old_root->operand(1)->operand(1))))); } // Tests while init instruction which points-to a constant. // // init = Tuple(Constant(S32, {}), Constant(F32, {8})) // -// CopyInsertion pass should generate: -// -// Tuple // old init -// / \ -// GTE(0) GTE(1) -// | | -// Copy Copy -// \ / -// Tuple // new init +// CopyInsertion pass should add copies for both constants. // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); auto old_init = while_hlo->operand(0); - InsertCopies(&module_); - auto new_init = while_hlo->operand(0); + InsertCopies(module_.get()); - // Check all paths from 'new_init' to 'old_init'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which points-to a parameter. // // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) // -// CopyInsertion pass should generate: -// -// Tuple // old init -// / \ -// GTE(0) GTE(1) -// | | -// Copy Copy -// \ / -// Tuple // new init +// CopyInsertion pass should add copies for both the constant and parameter. // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); auto old_init = while_hlo->operand(0); - InsertCopies(&module_); - auto new_init = while_hlo->operand(0); + InsertCopies(module_.get()); - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which has an ambiguous points-to set. @@ -1006,7 +960,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // select = Select(pred, tuple1, tuple2) // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with some of +// the actual GTE and Tuple instructions optimized away: // // Tuple // old init // / \ @@ -1027,40 +982,22 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); auto old_init = while_hlo->operand(0); - InsertCopies(&module_); - auto new_init = while_hlo->operand(0); + InsertCopies(module_.get()); - // Check all paths from 'new_init' to 'old_init'. 
- OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple( + op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), + op::Copy(op::GetTupleElement(old_init->operand(1)))))); } // Tests while init instruction which has a non-distinct points-to set. // // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with some of +// the actual GTE and Tuple instructions optimized away: // // Tuple // old init // / \ @@ -1081,73 +1018,116 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); auto old_init = while_hlo->operand(0); - InsertCopies(&module_); - auto new_init = while_hlo->operand(0); + InsertCopies(module_.get()); - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(old_init->operand(1)->operand(0)), + op::Copy(old_init->operand(1)->operand(0))))); } -// Tests while init instruction buffer which interfers with while result buffer. +// Tests while init instruction buffer which interferes with while result +// buffer. // // init_data = Broadcast(...) 
 // add_unrelated = Add(init_data) // takes a reference to cause interference
 // init = Tuple(Constant(S32, {}), init_data)
 //
-// CopyInsertion pass should generate:
-//
-//                    Tuple  // old init
-//                   /     \
-//              GTE(0)    GTE(1)
-//                |          |
-//              Copy       Copy
-//                 \        /
-//                   Tuple  // new init
+// CopyInsertion pass should copy both operands.
 //
 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
   auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
   auto old_init = while_hlo->operand(0);
-  InsertCopies(&module_);
-  auto new_init = while_hlo->operand(0);
+  InsertCopies(module_.get());
 
-  // Check all paths from 'new_init' to 'old_init'.
-  OperandTree op_tree;
-  op_tree.opcode = HloOpcode::kTuple;
+  EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
+                                               op::Copy(old_init->operand(1))));
+}
 
-  op_tree.op(0).opcode = HloOpcode::kCopy;
-  op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
-  op_tree.op(0).op(0).op(0).instruction = old_init;
+// Tests while init instruction buffer which has a non-distinct points-to set:
+//
+//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
+//                  Parameter(F32, {8}))
+//
+// where the second and third parameters are identical *and* the tuple is
+// shared by another while instruction.
+//
+// Verifies that the points-to set of the resulting Tuple is distinct (i.e.
+// the copies are non-identical Copy instructions). In other words, verifies
+// that copy sharing does not insert identical copies into the resulting tuple.
+TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
+  auto condition1 =
+      module_->AddEmbeddedComputation(BuildConditionComputation());
+  auto condition2 =
+      module_->AddEmbeddedComputation(BuildConditionComputation());
+  // Loop bodies whose output tuple elements depend on the init tuple.
+  auto body1 =
+      module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
+  auto body2 =
+      module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
 
-  op_tree.op(1).opcode = HloOpcode::kCopy;
-  op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
-  op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple;
-  op_tree.op(1).op(0).op(0).instruction = old_init;
+  auto builder = HloComputation::Builder(TestName() + ".While");
 
-  op_tree.Check(new_init);
+  auto iter_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
+  auto data_param = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, data_shape_, "data"));
+
+  // Loop init tuple contains two identical parameter buffers.
+  auto loop_init = builder.AddInstruction(
+      HloInstruction::CreateTuple({iter_param, data_param, data_param}));
+
+  const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+      {induction_variable_shape_, data_shape_, data_shape_});
+
+  // Two while loops share the same loop init tuple.
+  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition1, body1, loop_init));
+  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
+      loop_state_shape, condition2, body2, loop_init));
+
+  module_->AddEntryComputation(builder.Build());
+
+  auto points_to_analysis =
+      TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+
+  // Assert that the init tuples are non-distinct before copy insertion.
+ ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + + auto old_init1 = while_hlo1->operand(0); + auto old_init2 = while_hlo2->operand(0); + + InsertCopies(module_.get()); + + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Copy(old_init1->operand(0)), + op::Copy(old_init1->operand(1)), + op::Copy(old_init1->operand(2)))); + + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Copy(old_init2->operand(0)), + op::Copy(old_init2->operand(1)), + op::Copy(old_init2->operand(2)))); + + // Verifies the init tuples after copy insertion is distinct. + points_to_analysis = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + const auto& points_to1 = + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); + EXPECT_TRUE(points_to1.IsDistinct()); + + const auto& points_to2 = + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); + EXPECT_TRUE(points_to2.IsDistinct()); } } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 3d2df5a459b..51ecbccd494 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -53,21 +53,24 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/port:initialize", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", @@ -95,6 +98,7 @@ cc_library( name = "simple_orc_jit", srcs = ["simple_orc_jit.cc"], hdrs = ["simple_orc_jit.h"], + linkopts = ["-ldl"], deps = [ ":compiler_functor", ":cpu_runtime", @@ -135,7 +139,6 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", @@ -163,7 +166,6 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - 
"//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", @@ -412,6 +414,7 @@ cc_test( cc_test( name = "infeed_manager_test", + size = "small", srcs = ["infeed_manager_test.cc"], deps = [ ":cpu_runtime", @@ -504,6 +507,7 @@ cc_library( cc_test( name = "conv_canonicalization_test", + size = "small", srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", @@ -511,7 +515,6 @@ cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 89b3302bca0..8ebf9ab110d 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -81,12 +81,17 @@ operator()(llvm::Module& module) const { // Run optimization passes on module. function_passes.doInitialization(); + + CHECK(!llvm::verifyModule(module, &llvm::dbgs())); + for (auto func = module.begin(); func != module.end(); ++func) { function_passes.run(*func); } function_passes.doFinalization(); module_passes.run(module); + CHECK(!llvm::verifyModule(module, &llvm::dbgs())); + // Buffer for holding machine code prior to constructing the ObjectFile. llvm::SmallVector stream_buffer; llvm::raw_svector_ostream ostream(stream_buffer); @@ -192,8 +197,6 @@ void CompilerFunctor::AddOptimizationPasses( module_passes->add(createTargetTransformInfoWrapperPass( target_machine_->getTargetIRAnalysis())); - module_passes->add(llvm::createVerifierPass()); - llvm::PassManagerBuilder builder; builder.OptLevel = opt_level_; builder.SizeLevel = 0; @@ -212,8 +215,6 @@ void CompilerFunctor::AddOptimizationPasses( builder.populateFunctionPassManager(*function_passes); builder.populateModulePassManager(*module_passes); - - module_passes->add(llvm::createVerifierPass()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d18141af83e..f5ad431277d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" @@ -28,6 +29,8 @@ limitations under the License. namespace xla { namespace cpu { +using ::testing::ElementsAre; + class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { @@ -78,7 +81,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), input, kernel, conv_window_, dnums)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); @@ -96,17 +99,14 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { // The input is in CNHW order. input_reshape should produce // NHWC for the convolution to hit the Eigen fast path. 
- EXPECT_TRUE(ContainersEqual(input_reshape->dimensions(), - std::vector({1, 2, 3, 0}))); + EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0)); // The kernel is in OIHW order. kernel_reshape should produce // HWIO for the convolution to hit the Eigen fast path. - EXPECT_TRUE(ContainersEqual(kernel_reshape->dimensions(), - std::vector({2, 3, 1, 0}))); + EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0)); // The output of the canonical convolution is in NHWC order (the same as // input_reshape's order). output_reshape should restore that order to the // order of the computation root (CNHW). - EXPECT_TRUE(ContainersEqual(output_reshape->dimensions(), - std::vector({3, 0, 1, 2}))); + EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2)); } TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { @@ -135,7 +135,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), input, kernel, conv_window_, dnums)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); ConvCanonicalization conv_canonicalization; @@ -144,3 +144,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { } // namespace cpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 712c180f95f..34b99f2440b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include // NOLINT(build/c++11): only using std::call_once, not mutex. #include #include #include @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/port/initialize.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -58,7 +58,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -66,7 +69,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" @@ -140,22 +145,54 @@ CpuCompiler::CpuCompiler() { LLVMInitializePowerPCTargetMC(); LLVMInitializePowerPCAsmPrinter(); LLVMInitializePowerPCDisassembler(); +} - // LLVM command-line flags are global, so set them during initialization. - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (!flags->xla_cpu_llvm_cl_opts.empty()) { - std::vector opts = - tensorflow::str_util::Split(flags->xla_cpu_llvm_cl_opts, ','); +namespace { + +const char* kXlaParallelCpuOption = "xla_cpu_parallel"; + +// LLVM makes certain options configurable only through its command-line +// options; it provide the ParseCommandLineOptions function that lets us set +// flags at runtime. However, since these flags are global we want to avoid +// multiple invocations of the LLVM compilation pipeline with a different set of +// flags. Therefore, we only pass command-line flags to LLVM once, before the +// first module is compiled. +std::once_flag llvm_command_line_options_initialized; + +void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { + auto options = config.debug_options().xla_backend_extra_options(); + if (!options.empty()) { + std::vector fake_argv_storage; + fake_argv_storage.push_back(""); + for (const auto& it : options) { + // Skip options the XLA backend itself consumes. + if (it.first != kXlaParallelCpuOption) { + if (it.second.empty()) { + fake_argv_storage.push_back(it.first); + } else { + fake_argv_storage.push_back(it.first + "=" + it.second); + } + } + } + + VLOG(2) << "Passing argv to LLVM:"; std::vector fake_argv; - fake_argv.push_back(""); - for (const string& opt : opts) { - fake_argv.push_back(opt.c_str()); + for (const auto& s : fake_argv_storage) { + fake_argv.push_back(s.c_str()); + VLOG(2) << s; } llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); } } -namespace { +// Helps determine whether the parallel CPU backend was requested in the options +// of this module configuration. +bool CpuParallelBackendRequested(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaParallelCpuOption) > 0; +} + // This visitor records which HLO instructions should have profiling information // recorded. class CollectProfileCandidates : public DfsHloVisitorWithDefault { @@ -190,16 +227,16 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { } // It is important to recurse for "while" or else we risk overly coarse // profiling information. 
- Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, - HloComputation* condition, HloComputation* body) override { + Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_); - TF_RETURN_IF_ERROR( - condition->root_instruction()->Accept(&candidates_for_condition)); + TF_RETURN_IF_ERROR(xla_while->while_condition()->root_instruction()->Accept( + &candidates_for_condition)); CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); - TF_RETURN_IF_ERROR(body->root_instruction()->Accept(&candidates_for_body)); + TF_RETURN_IF_ERROR(xla_while->while_body()->root_instruction()->Accept( + &candidates_for_body)); return Status::OK(); } @@ -208,64 +245,86 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* hlo_module, - HloModuleConfig* module_config, - HloDumper dump_hlo) { +Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { // Optimization pipeline. HloPassPipeline pipeline("CPU", dump_hlo); - pipeline.AddPass(); + pipeline.AddInvariantChecker(); + + // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding + // where we will take this pass in future. + // pipeline.AddPass(); + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification", dump_hlo); pass.AddPass( /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }); + [](const Shape&, const Shape&) { return false; }, + /*enable_dot_simplification=*/false); pass.AddPass(); + pass.AddPass(); } - pipeline.AddPass(PotentiallyImplementedAsEigenDot); - pipeline.AddPass(); + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot) + ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); pipeline.AddPass( - module_config->mutable_entry_computation_layout()); + module->mutable_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }); + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_simplification=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); // Outline ops in the entry computation into calls to subcomputations. - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (flags->xla_cpu_parallel) { + if (CpuParallelBackendRequested(module->config())) { pipeline.AddPass(); } - // Copy insertion should be performed immediately before IR emission to - // avoid inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes - // an instruction which materializes a value). + // Copy insertion should be performed immediately before IR emission to avoid + // inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes an + // instruction which materializes a value). DCE must be run immediately before + // (and sometime after) copy insertion, to avoid dead code from interfering + // with the rewrites. 
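The `TransposeFolding` pass above now takes a callback: the pass proposes the operand indices it could fold and the backend returns the subset it actually wants folded, with an empty result meaning "fold nothing". A standalone sketch of that contract; `Instruction`, `OperandIndices`, and the Eigen test are stand-ins for the XLA types:

```c++
#include <cstdint>
#include <functional>
#include <vector>

struct Instruction { bool lowers_to_eigen_dot = false; };
using OperandIndices = std::vector<int64_t>;
using FoldFilter =
    std::function<OperandIndices(const Instruction&, const OperandIndices&)>;

// Mirrors the lambda above: keep every candidate operand only when the dot
// will hit the Eigen fast path, where folding the transpose pays off.
OperandIndices KeepIfEigenDot(const Instruction& dot,
                              const OperandIndices& candidates) {
  return dot.lowers_to_eigen_dot ? candidates : OperandIndices{};
}
```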
+ pipeline.AddPass(); pipeline.AddPass(); - if (flags->xla_cpu_parallel) { + if (CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. pipeline.AddPass(); } pipeline.AddPass(); - return pipeline.Run(hlo_module).status(); + pipeline.AddPass(); + return pipeline.Run(module).status(); } namespace { +// Align buffers to 16-byte boundaries. +constexpr int64 kMemoryAlignment = 16; + llvm::TargetOptions CompilerTargetOptions( - const HloModuleConfig& execution_options) { + const HloModuleConfig& module_config) { llvm::TargetOptions target_options; - llvm_ir::SetTargetOptions(execution_options, &target_options); + llvm_ir::SetTargetOptions( + /*fast_math_enabled=*/module_config.debug_options() + .xla_enable_fast_math(), + &target_options); return target_options; } -llvm::CodeGenOpt::Level CodeGenOptLevel() { - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - switch (flags->xla_cpu_llvm_opt_level) { +llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { + VLOG(2) << "backend_optimization_level: " + << module_config.debug_options().xla_backend_optimization_level(); + switch (module_config.debug_options().xla_backend_optimization_level()) { case 1: return llvm::CodeGenOpt::Less; case 2: @@ -282,28 +341,26 @@ llvm::CodeGenOpt::Level CodeGenOptLevel() { } // namespace StatusOr> CpuCompiler::Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); + std::call_once(llvm_command_line_options_initialized, + &InitializeLLVMCommandLineOptions, module->config()); // Compile must be thread-safe so create a new LLVM context for the module. auto llvm_context = MakeUnique(); auto llvm_module = MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique(CompilerTargetOptions(*module_config), - CodeGenOptLevel()); + auto jit = MakeUnique(CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - const llvm::DataLayout& data_layout = llvm_module->getDataLayout(); - int64 pointer_size = data_layout.getPointerSize(); - TF_RETURN_IF_ERROR( - RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), dump_hlo)); - HloComputation* computation = hlo_module->entry_computation(); + HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; - if (module_config->hlo_profiling_enabled()) { + if (module->config().hlo_profiling_enabled()) { TF_ASSIGN_OR_RETURN( hlo_to_profile_idx, CollectProfileCandidates::GetCandidatesForComputation(computation)); @@ -311,7 +368,7 @@ StatusOr> CpuCompiler::Compile( std::unique_ptr cpu_executable; legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (flags->xla_cpu_parallel) { + if (CpuParallelBackendRequested(module->config())) { // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. // DependencyHloOrdering is used for the parallel emitter because the order @@ -320,9 +377,15 @@ StatusOr> CpuCompiler::Compile( // uses data dependencies for determining order. 
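`CodeGenOptLevel` above now reads the backend optimization level from the module's debug options instead of a global legacy flag. A standalone sketch of the mapping; only cases 1 and 2 are visible in this hunk, so the `Aggressive` case and the `None` fallback are assumptions patterned on common LLVM usage, and `CodeGenOpt` stands in for `llvm::CodeGenOpt::Level`:

```c++
#include <cstdint>

enum class CodeGenOpt { None, Less, Default, Aggressive };

CodeGenOpt CodeGenOptLevelFor(int64_t backend_optimization_level) {
  switch (backend_optimization_level) {
    case 1:
      return CodeGenOpt::Less;
    case 2:
      return CodeGenOpt::Default;
    case 3:
      return CodeGenOpt::Aggressive;  // Assumed; not shown in the hunk.
    default:
      return CodeGenOpt::None;        // Assumed safe fallback.
  }
}
```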
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(hlo_module.get(), - MakeUnique(hlo_module.get()), - pointer_size)); + BufferAssigner::Run(module.get(), + MakeUnique(module.get()), + BufferSizeBytesFunction(), kMemoryAlignment)); + + if (!flags->xla_cpu_dump_debug_json_to.empty()) { + HloProto proto = MakeHloProto(*module, *assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( + proto, flags->xla_cpu_dump_debug_json_to, module->name())); + } // If we are using the parallel CPU backend, we need to create map from // HloInstruction to the corresponding generated function name. @@ -338,7 +401,7 @@ StatusOr> CpuCompiler::Compile( // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. const void* data = LiteralUtil::InternalData(instruction->literal()); - int64 size = llvm_ir::ByteSizeOf(instruction->shape(), data_layout); + int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, MakeUnique(size)); CHECK_EQ(iter.second, true); @@ -348,13 +411,14 @@ StatusOr> CpuCompiler::Compile( } // The parallel preparation should have ensured that the top-level // computation consists solely of Call instructions. - TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall); + TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall) + << module->ToString(); HloComputation* to_apply = instruction->to_apply(); parallel_computations.emplace(to_apply, instruction); } - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, - llvm_module.get(), &hlo_to_profile_idx); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + &hlo_to_profile_idx); std::unique_ptr> function_names( new std::map()); for (auto embedded_computation : @@ -369,7 +433,8 @@ StatusOr> CpuCompiler::Compile( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel)); + /*is_entry_computation=*/computation_is_parallel, + /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for // the Call instruction. @@ -388,9 +453,9 @@ StatusOr> CpuCompiler::Compile( // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( - std::move(jit), std::move(assignment), std::move(hlo_module), - std::move(module_config), std::move(function_names), - std::move(hlo_to_profile_idx), std::move(aligned_constants))); + std::move(jit), std::move(assignment), std::move(module), + std::move(function_names), std::move(hlo_to_profile_idx), + std::move(aligned_constants))); if (flags->xla_cpu_embed_ir) { static_cast(*cpu_executable) @@ -402,26 +467,29 @@ StatusOr> CpuCompiler::Compile( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence( - *hlo_module, [&](const LogicalBuffer& buffer) { - return llvm_ir::ByteSizeOf(buffer.shape(), data_layout); - })); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
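Earlier in this hunk, each constant is copied out of its ProtocolBuffer so it can be given a stronger alignment; `kMemoryAlignment` is 16 bytes in this patch. A self-contained sketch of that copy, using `std::aligned_alloc` purely as an illustrative mechanism, not the allocator XLA actually uses:

```c++
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>

constexpr std::size_t kMemoryAlignment = 16;

struct FreeDeleter {
  void operator()(void* p) const { std::free(p); }
};

// Data inside a protobuf has no alignment guarantee, so copy it into a
// buffer aligned to the boundary the emitted code assumes.
std::unique_ptr<unsigned char, FreeDeleter> CopyAligned(const void* data,
                                                        std::size_t size) {
  // std::aligned_alloc requires the size to be a multiple of the alignment.
  std::size_t padded = (size + kMemoryAlignment - 1) & ~(kMemoryAlignment - 1);
  auto* buffer = static_cast<unsigned char*>(
      std::aligned_alloc(kMemoryAlignment, padded));
  std::memcpy(buffer, data, size);
  return std::unique_ptr<unsigned char, FreeDeleter>(buffer);
}
```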
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(hlo_module.get(), - MakeUnique(hlo_module.get(), - module_sequence), - pointer_size)); + BufferAssigner::Run( + module.get(), + MakeUnique(module.get(), module_sequence), + BufferSizeBytesFunction(), kMemoryAlignment)); + + if (!flags->xla_cpu_dump_debug_json_to.empty()) { + HloProto proto = MakeHloProto(*module, *assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( + proto, flags->xla_cpu_dump_debug_json_to, module->name())); + } // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, - llvm_module.get(), &hlo_to_profile_idx); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + &hlo_to_profile_idx); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -448,10 +516,9 @@ StatusOr> CpuCompiler::Compile( // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); - cpu_executable.reset( - new CpuExecutable(std::move(jit), std::move(assignment), - std::move(hlo_module), std::move(module_config), - function_name, std::move(hlo_to_profile_idx))); + cpu_executable.reset(new CpuExecutable( + std::move(jit), std::move(assignment), std::move(module), function_name, + std::move(hlo_to_profile_idx))); if (flags->xla_cpu_embed_ir) { static_cast(*cpu_executable) @@ -463,30 +530,31 @@ StatusOr> CpuCompiler::Compile( } StatusOr>> CpuCompiler::Compile( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector> modules, HloDumper dump_hlos, + std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); } StatusOr>> -CpuCompiler::CompileAheadOfTime( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { - TF_RET_CHECK(hlo_modules.size() == module_configs.size()); - TF_RET_CHECK(!hlo_modules.empty()); +CpuCompiler::CompileAheadOfTime(std::vector> modules, + HloDumper dump_hlo, + const AotCompilationOptions& aot_options) { + TF_RET_CHECK(!modules.empty()); + std::call_once(llvm_command_line_options_initialized, + &InitializeLLVMCommandLineOptions, modules[0]->config()); // We can pass just one llvm::TargetOptions when we compile the LLVM module, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. 
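The check that follows enforces the invariant spelled out in the comment: one `llvm::TargetOptions` serves every module in an ahead-of-time batch, so any option feeding it, here fast-math, must agree across all of them. A minimal sketch of the same validation, with `ModuleConfig` as a stand-in:

```c++
#include <vector>

struct ModuleConfig { bool fast_math_enabled = false; };

// True iff every module in the batch agrees with the first one; the real
// code returns InvalidArgument on the first mismatch instead.
bool AllFastMathFlagsAgree(const std::vector<ModuleConfig>& configs) {
  if (configs.empty()) return true;
  const bool expected = configs.front().fast_math_enabled;
  for (const ModuleConfig& config : configs) {
    if (config.fast_math_enabled != expected) return false;
  }
  return true;
}
```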
- bool fast_math_disabled = module_configs[0]->fast_math_disabled(); - for (const auto& module_config : module_configs) { - if (module_config->fast_math_disabled() != fast_math_disabled) { + const bool fast_math_enabled = + modules[0]->config().debug_options().xla_enable_fast_math(); + for (const auto& module : modules) { + if (module->config().debug_options().xla_enable_fast_math() != + fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " - "fast_math_disabled."); + "xla_enable_fast_math."); } } @@ -505,9 +573,9 @@ CpuCompiler::CompileAheadOfTime( error.c_str()); } - llvm::Reloc::Model reloc_model; - llvm::PICLevel::Level pic_level; - llvm::PIELevel::Level pie_level; + llvm::Reloc::Model reloc_model = llvm::Reloc::Static; + llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC; + llvm::PIELevel::Level pie_level = llvm::PIELevel::Default; switch (options.relocation_model()) { case CpuAotCompilationOptions::RelocationModel::Static: reloc_model = llvm::Reloc::Static; @@ -537,11 +605,11 @@ CpuCompiler::CompileAheadOfTime( } llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); - llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(); + llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); std::unique_ptr target_machine = WrapUnique(target->createTargetMachine( triple.getTriple(), cpu_name, features, - CompilerTargetOptions(*module_configs[0]), reloc_model, + CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::CodeModel::Default, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. @@ -555,34 +623,35 @@ CpuCompiler::CompileAheadOfTime( if (pie_level != llvm::PIELevel::Default) { llvm_module.setPIELevel(pie_level); } - const llvm::DataLayout& data_layout = llvm_module.getDataLayout(); - int64 pointer_size = data_layout.getPointerSize(); std::vector> results; - for (int i = 0; i < hlo_modules.size(); ++i) { - HloModule* hlo_module = hlo_modules[i].get(); - HloModuleConfig* module_config = module_configs[i].get(); + for (size_t i = 0; i < modules.size(); ++i) { + HloModule* module = modules[i].get(); - TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence( - *hlo_module, [&](const LogicalBuffer& buffer) { - return llvm_ir::ByteSizeOf(buffer.shape(), data_layout); - })); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
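Earlier in this hunk, `reloc_model`, `pic_level`, and `pie_level` gain explicit initializers so that no path through the following `switch` can read an uninitialized value. A standalone sketch of the pattern; the enums are stand-ins for the LLVM ones:

```c++
enum class RelocModel { Static, PIC };
enum class OptionsModel { Static, SmallPic, BigPic };

RelocModel PickRelocModel(OptionsModel option) {
  // Default first, as in the patch: even if a new enumerator is added to
  // OptionsModel and this switch is not updated, the result is defined.
  RelocModel reloc = RelocModel::Static;
  switch (option) {
    case OptionsModel::Static:
      reloc = RelocModel::Static;
      break;
    case OptionsModel::SmallPic:
    case OptionsModel::BigPic:
      reloc = RelocModel::PIC;
      break;
  }
  return reloc;
}
```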
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(hlo_module, MakeUnique( - hlo_module, module_sequence), - pointer_size)); + BufferAssigner::Run( + module, MakeUnique(module, module_sequence), + BufferSizeBytesFunction(), kMemoryAlignment)); - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module, + legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + if (!flags->xla_cpu_dump_debug_json_to.empty()) { + HloProto proto = MakeHloProto(*module, *assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( + proto, flags->xla_cpu_dump_debug_json_to, module->name())); + } + + IrEmitter ir_emitter(*module, *assignment, &llvm_module, /*hlo_to_profile_idx=*/nullptr); - HloComputation* computation = hlo_module->entry_computation(); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -597,7 +666,8 @@ CpuCompiler::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true)); + /*is_entry_computation=*/true, + &module_sequence.at(computation))); entry_function->setName(llvm_ir::AsStringRef(entry_point_name)); @@ -627,12 +697,12 @@ CpuCompiler::CompileAheadOfTime( buffer_sizes.push_back(allocation.size()); } - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment->GetUniqueTopLevelOutputAllocation()); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment->GetUniqueTopLevelOutputSlice()); results.emplace_back(MakeUnique( std::move(object_file_data), std::move(buffer_sizes), - result_allocation->index())); + result_slice.index())); } return std::move(results); } @@ -641,11 +711,17 @@ se::Platform::Id CpuCompiler::PlatformId() const { return se::host::kHostPlatformId; } +HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { + return CpuExecutable::ShapeSizeBytes; +} + } // namespace cpu } // namespace xla -REGISTER_MODULE_INITIALIZER(cpu_compiler, { +static bool InitModule() { xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { return xla::MakeUnique(); }); -}); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d7d77ce58a6..29fa4eac61b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" @@ -33,8 +32,6 @@ namespace cpu { // This class wraps the configurability options that LLVM exposes including: the // target triple, the target cpu and the target features. It also includes the // desired linkage name for the computation entry point. -// Note that the optimization level can be controlled by the -// --xla_cpu_llvm_opt_level flag. class CpuAotCompilationOptions : public AotCompilationOptions { public: // Relocation models available for compilation. 
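At the end of cpu_compiler.cc above, `REGISTER_MODULE_INITIALIZER` is replaced by a plain function invoked from a namespace-scope static initializer. A self-contained sketch of that idiom, where `RegisterCompilerFactory` is a trivial stand-in for the real registry call:

```c++
#include <functional>
#include <iostream>

// A real registry would store the factory keyed by platform id; this
// stand-in just runs it so the sketch is observable.
void RegisterCompilerFactory(std::function<void()> factory) { factory(); }

static bool InitModule() {
  RegisterCompilerFactory([] { std::cout << "cpu compiler registered\n"; });
  return true;
}

// The dynamic initializer of this namespace-scope static runs before main(),
// which is what makes the registration happen without an explicit call.
static bool module_initialized = InitModule();

int main() {}  // By the time main runs, the factory is registered.
```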
@@ -113,32 +110,29 @@ class CpuCompiler : public Compiler { ~CpuCompiler() override {} StatusOr> Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> hlo_module, - std::vector> module_config, - HloDumper dump_hlo, + std::vector> modules, HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> - CompileAheadOfTime( - std::vector> module, - std::vector> module_config, - HloDumper dump_hlo, const AotCompilationOptions& options) override; + CompileAheadOfTime(std::vector> modules, + HloDumper dump_hlo, + const AotCompilationOptions& options) override; perftools::gputools::Platform::Id PlatformId() const override; + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; + private: // Initialize the LLVM target. static void InitializeLLVMTarget(); // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* hlo_module, HloModuleConfig* module_config, - HloDumper dump_hlo); + Status RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 727257d4f11..671d6957a39 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -53,11 +52,9 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr module_config, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unordered_map hlo_to_profile_idx) - : Executable(std::move(hlo_module), std::move(module_config)), + : Executable(std::move(hlo_module), CpuExecutable::ShapeSizeBytes), jit_(std::move(jit)), assignment_(std::move(assignment)), hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { @@ -135,10 +132,9 @@ Status CpuExecutable::AllocateBuffers( TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); } - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelOutputAllocation()); - - VLOG(3) << "result index: " << result_allocation->index(); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + VLOG(3) << "result index: " << result_slice.index(); return Status::OK(); } @@ -193,9 +189,9 @@ Status CpuExecutable::ExecuteComputeFunction( for (auto& buffer : buffers) { buffer_pointers.push_back(const_cast(buffer.opaque())); } - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelOutputAllocation()); - void* result_buffer = buffer_pointers[result_allocation->index()]; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + void* result_buffer = 
buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( @@ -231,7 +227,8 @@ Status CpuExecutable::ExecuteComputeFunction( } if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed(profile_counters.back()); + hlo_execution_profile->set_total_cycles_executed( + *module().entry_computation(), profile_counters.back()); for (auto hlo_prof_idx : hlo_to_profile_idx_) { const HloInstruction* hlo = hlo_prof_idx.first; @@ -243,24 +240,24 @@ Status CpuExecutable::ExecuteComputeFunction( } StatusOr CpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); + TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, - hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunction( + &run_options->run_options(), arguments, buffers, hlo_execution_profile)); // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelOutputAllocation()); - se::DeviceMemoryBase top_level_output = buffers[result_allocation->index()]; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), &marked_addresses); @@ -275,10 +272,9 @@ StatusOr CpuExecutable::ExecuteOnStream( // Computation is done - deallocate temp buffers. Keep those marked live // because they are referenced by the output of the computation and are needed // by the service. They will be deallocated by the service. 
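The loop that follows implements the policy described in the comment above: every temporary whose address is not referenced by the output is deallocated, while live buffers are left for the service to free later. A standalone sketch of that mark-and-free pass; `Buffer` and the deallocate callback are stand-ins for `se::DeviceMemoryBase` and the device allocator:

```c++
#include <cstddef>
#include <unordered_set>
#include <vector>

struct Buffer {
  void* opaque = nullptr;
  bool is_null() const { return opaque == nullptr; }
};

void FreeDeadBuffers(std::vector<Buffer>& buffers,
                     const std::unordered_set<const void*>& live_addresses,
                     void (*deallocate)(Buffer*)) {
  for (std::size_t i = 0; i < buffers.size(); ++i) {
    Buffer& buffer = buffers[i];
    // Buffers named in the output stay alive; the service frees them later.
    if (!buffer.is_null() && live_addresses.count(buffer.opaque) == 0) {
      deallocate(&buffer);
    }
  }
}
```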
- for (auto i = 0; i < buffers.size(); ++i) { - auto alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && - alloc.opaque() != nullptr) { + for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; + if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; TF_RETURN_IF_ERROR(memory_allocator->Deallocate( @@ -290,37 +286,35 @@ StatusOr CpuExecutable::ExecuteOnStream( } StatusOr> CpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); } + + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_buffer, - ShapedBuffer::MakeShapedBuffer( - module_config().entry_computation_layout().result_shape(), - stream->parent()->platform(), stream->parent()->device_ordinal())); - + TF_ASSIGN_OR_RETURN(std::unique_ptr result_buffer, + ShapedBuffer::MakeShapedBuffer( + result_shape(), stream->parent()->platform(), + stream->parent()->device_ordinal())); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, - hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunction( + &run_options->run_options(), arguments, buffers, hlo_execution_profile)); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); TF_RETURN_IF_ERROR( result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( + ->ForEachMutableElementWithStatus( [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { - if (is_leaf) { + const ShapeIndex& index, size_t* buffer_entry) { + if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) { const std::vector& sources = this->GetRootPointsToSet().element(index); // The points to set is unambiguous so the set should be a @@ -334,24 +328,24 @@ StatusOr> CpuExecutable::ExecuteOnStream( // The source instruction should have a non-parameter buffer // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - this->assignment_->GetUniqueAllocation( + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( src, buffer_source->index())); - CHECK(!allocation->is_entry_computation_parameter()); + CHECK(!slice.allocation()->is_entry_computation_parameter()); - CHECK(!buffers[allocation->index()].is_null() || - buffers[allocation->index()].size() == 0); - result_buffer->mutable_buffers()->push_back( - buffers[allocation->index()]); - *buffer_entry = result_buffer->mutable_buffers()->size() - 1; - buffers_in_result[allocation->index()] = true; + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *buffer_entry = result_buffer->mutable_buffers()->size(); + result_buffer->mutable_buffers()->push_back(buffer); + buffers_in_result[buffer_index] = true; } return Status::OK(); })); // Free all buffers not in the result. - for (auto i = 0; i < buffers.size(); ++i) { - auto alloc = buffers[i]; + for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; if (!buffers_in_result[i] && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; @@ -363,111 +357,23 @@ StatusOr> CpuExecutable::ExecuteOnStream( return std::move(result_buffer); } -Status CpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // Every array element in the result of the computation must be unambiguously - // produced by a single instruction. - // This ensures that the buffers inside result_buffer can be assigned without - // conflict to the respective instructions because there is a one-to-one - // correspondence between hlo instructions and array buffers in the result. - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented( - "Points-to set of root instruction is ambiguous or not distinct"); - } - std::vector buffers(assignment_->Allocations().size()); - DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape())); - - // If two tuple elements point to the same buffer, one of the results in the - // result buffer is considered the canonical location while the other result - // points to it (instead of, say, making a copy of the result). - // buffer_index_to_shape_index maps a buffer index to its canonical location - // in the result buffer. - std::unordered_map - buffer_index_to_shape_index; - - // Copy values from result_buffer to the index in "buffers". These buffers - // will not be allocated in the call to AllocateBuffers. - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( - [&buffers, &buffers_in_result, &buffer_index_to_shape_index, - result_buffer, this](const ShapeIndex& index, bool is_leaf, - size_t* buffer_entry) { - if (is_leaf) { - const std::vector& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. 
- CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - this->assignment_->GetUniqueAllocation( - src, buffer_source->index())); - CHECK(!allocation->is_entry_computation_parameter()); - - auto insert_result = buffer_index_to_shape_index.emplace( - allocation->index(), *buffer_entry); - if (insert_result.second) { - // The points-to set is distinct so this buffer should not - // have - // been assigned in a previous invocation of this lambda. - perftools::gputools::DeviceMemoryBase memory_base = - result_buffer->buffer(index); - CHECK(buffers[allocation->index()].is_null()); - CHECK(!memory_base.is_null()); - buffers[allocation->index()] = memory_base; - buffers_in_result[allocation->index()] = true; - } else { - // Record the fact that this tuple element is identical to - // some - // prior result. - *buffer_entry = insert_result.first->second; - } - } - return Status::OK(); - })); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers, - hlo_execution_profile)); - - // Free all buffers not in the result. - for (auto i = 0; i < buffers.size(); ++i) { - auto alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - StatusOr CpuExecutable::ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); } +/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); +} + const PointsToSet& CpuExecutable::GetRootPointsToSet() const { return assignment_->points_to_analysis().GetPointsToSet( module().entry_computation()->root_instruction()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8f3247e6834..b5746769ba7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -52,30 +51,23 @@ class CpuExecutable : public Executable { std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, const string& entry_function_name, std::unordered_map hlo_to_profile_idx); ~CpuExecutable() override {} StatusOr ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr> ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - Status ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; @@ -86,6 +78,8 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } + static int64 ShapeSizeBytes(const Shape& shape); + private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 240da35ef19..dc002846e9e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -24,6 +24,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on CPUs. 
+ if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + // Condition for consumer: must be elementwise or a fusion op // (which necessarily only contains elementwise operations) if (!(consumer->opcode() == HloOpcode::kFusion || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index b7c646ad47d..0eca4c3473e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -24,8 +24,9 @@ namespace cpu { class CpuInstructionFusion : public InstructionFusion { public: - CpuInstructionFusion() {} - ~CpuInstructionFusion() override {} + CpuInstructionFusion() + : InstructionFusion(CpuInstructionFusion::IsExpensive) {} + ~CpuInstructionFusion() override = default; protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8e06f0520ed..253de20f251 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include #include #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 981f24ca6f5..7ad497ff1a2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -63,8 +63,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const HloModuleConfig& hlo_module_config, - const BufferAssignment& assignment, llvm::Module* llvm_module, + const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx) : assignment_(assignment), module_(llvm_module), @@ -72,8 +72,10 @@ IrEmitter::IrEmitter( ir_builder_(llvm_module->getContext()), hlo_to_profile_idx_(hlo_to_profile_idx), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), - hlo_module_config_(hlo_module_config) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(hlo_module_config)); + hlo_module_config_(hlo_module.config()) { + ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + /*fast_math_enabled=*/hlo_module_config_.debug_options() + .xla_enable_fast_math())); } StatusOr IrEmitter::EmitComputation( @@ -201,7 +203,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name, if (&argument == retval) { continue; } - compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + compute_function_->addAttribute(argument.getArgNo() + 1, + llvm::Attribute::NoAlias); } ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( @@ -506,7 +509,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, llvm_ir::IrArray::Index input_index(index.size()); llvm::Value* in_bounds_condition = nullptr; - for (int64 i = 0; i < index.size(); ++i) { + for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( index[i], ir_builder_.getInt64(window.dimensions(i).stride())); input_index[i] = ir_builder_.CreateNSWSub( @@ -1111,7 +1114,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; 
llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (int64 i = 0; i < input_index.size(); ++i) { + for (size_t i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -1136,6 +1139,41 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { + if (ShapeUtil::IsScalar(slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(slice)); + emitted_value_[slice] = target_address; + return EmitMemcpy(*operand, *slice); + } + return DefaultAction(slice); +} + +Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* /*start_indices*/) { + if (ShapeUtil::IsScalar(dynamic_slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_slice)); + emitted_value_[dynamic_slice] = target_address; + return EmitMemcpy(*operand, *dynamic_slice); + } + return DefaultAction(dynamic_slice); +} + +Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* /*operand*/, + HloInstruction* update, + HloInstruction* /*start_indices*/) { + if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_update_slice)); + emitted_value_[dynamic_update_slice] = target_address; + return EmitMemcpy(*update, *dynamic_update_slice); + } + return DefaultAction(dynamic_update_slice); +} + Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Recv is not implemented on CPU. 
See b/33942983."); @@ -1180,7 +1218,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index; - for (int64 i = 0; i < operand_index.size(); ++i) { + for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + @@ -1265,13 +1303,12 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { } } -Status IrEmitter::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { +Status IrEmitter::HandleCall(HloInstruction* call) { + HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); std::vector parameter_addresses; - for (HloInstruction* operand : operands) { + for (const HloInstruction* operand : call->operands()) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } @@ -1294,12 +1331,12 @@ Status IrEmitter::HandleCustomCall( llvm_ir::EmitAllocaAtFunctionEntryWithCount( i8_ptr_type, ir_builder_.getInt32(operands.size()), "cc_operands_alloca", &ir_builder_); - for (int i = 0; i < operands.size(); ++i) { + for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( - operands_alloca, {ir_builder_.getInt32(i)}); + operands_alloca, {ir_builder_.getInt64(i)}); ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = @@ -1322,32 +1359,29 @@ Status IrEmitter::HandleCustomCall( return Status::OK(); } -Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) { +Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Precondition: Condition computation must return a scalar bool. + HloComputation* condition = xla_while->while_condition(); TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; - // Check that all while-related buffers share an allocation. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( + // Check that all while-related buffers share an allocation slice. 
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( xla_while->shape(), [this, &xla_while](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { auto check = [this](const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index) { - BufferAllocation::Index index_a = - assignment_.GetUniqueAllocation(a, index) - .ConsumeValueOrDie() - ->index(); - BufferAllocation::Index index_b = - assignment_.GetUniqueAllocation(b, index) - .ConsumeValueOrDie() - ->index(); - if (index_a != index_b) { + const BufferAllocation::Slice slice_a = + assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie(); + const BufferAllocation::Slice slice_b = + assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie(); + if (slice_a != slice_b) { return InternalError( - "instruction %s does not share allocation with " - "instruction %s ", - a->ToString().c_str(), b->ToString().c_str()); + "instruction %s %s does not share slice with " + "instruction %s %s", + a->ToString().c_str(), slice_a.ToString().c_str(), + b->ToString().c_str(), slice_b.ToString().c_str()); } return Status::OK(); }; @@ -1364,12 +1398,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init, })); // Set emitted value to that of 'init' with which it shares an allocation. + const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); // The called computation should have been emitted previously. llvm::Function* condition_ir_function = FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = FindOrDie(emitted_functions_, body); + llvm::Function* body_ir_function = + FindOrDie(emitted_functions_, xla_while->while_body()); // Generating: // while (Condition(while_result)) { @@ -1582,44 +1618,49 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { } llvm::Value* IrEmitter::EmitTempBufferPointer( - BufferAllocation::Index temp_buf_index, const Shape& target_shape) { + const BufferAllocation::Slice& slice, const Shape& target_shape) { llvm::Type* element_type = IrShapeType(target_shape); // The alignment and number of bytes within the temporary buffer is determined // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = - assignment_.GetAllocation(temp_buf_index); + const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); if (allocation.is_thread_local()) { // Thread-local allocations should only be assigned a single buffer. 
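The check above compares `BufferAllocation::Slice` values rather than bare allocation indices: a slice names an allocation plus a byte offset and size within it, so two logical buffers packed into one allocation remain distinguishable. A standalone sketch of the idea, a toy stand-in rather than the real class:

```c++
#include <cstdint>
#include <tuple>

struct Slice {
  int64_t allocation_index = 0;
  int64_t offset = 0;  // Byte offset within the allocation.
  int64_t size = 0;    // Byte size of this sub-buffer.

  friend bool operator==(const Slice& a, const Slice& b) {
    return std::tie(a.allocation_index, a.offset, a.size) ==
           std::tie(b.allocation_index, b.offset, b.size);
  }
  friend bool operator!=(const Slice& a, const Slice& b) { return !(a == b); }
};

// The while-buffer check above becomes a slice comparison: two sub-buffers
// share storage only if allocation, offset, and size all match.
bool SharesStorage(const Slice& a, const Slice& b) { return a == b; }
```

This is also why `EmitTempBufferPointer` below gains the offset-adjusting GEP: a slice's address is its allocation's base address plus the slice offset.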
- CHECK_EQ(1, allocation.assigned_buffers().size()); - const Shape& shape = allocation.assigned_buffers()[0]->shape(); + const auto& assigned_buffers = allocation.assigned_buffers(); + CHECK_EQ(1, assigned_buffers.size()); + const Shape& shape = assigned_buffers.begin()->first->shape(); llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{ - ir_builder_.GetInsertBlock()->getParent(), temp_buf_index}]; + ir_builder_.GetInsertBlock()->getParent(), slice}]; if (tempbuf_address == nullptr) { tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", temp_buf_index), + tensorflow::strings::StrCat("thread_local", slice.ToString()), &ir_builder_, MinimumAlignmentForShape(target_shape)); } return ir_builder_.CreateBitCast(tempbuf_address, element_type->getPointerTo()); } - llvm::Value* tempbuf_address_offset = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), temp_buf_index, &ir_builder_); - llvm::LoadInst* tempbuf_address_untyped = - ir_builder_.CreateLoad(tempbuf_address_offset); + llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( + GetTempBuffersArgument(), slice.index(), &ir_builder_); + llvm::LoadInst* tempbuf_address_base = + ir_builder_.CreateLoad(tempbuf_address_ptr); // Loading the address of a buffer is invariant of the point at which the // load is executed in the program because we never reassign buffers. - tempbuf_address_untyped->setMetadata( + tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(tempbuf_address_untyped->getContext(), /*MDs=*/{})); - llvm_ir::SetTbaaForInstruction(tempbuf_address_untyped, target_shape, + llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); + llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, /*is_pointer_to=*/true); + AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); + AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); - AttachAlignmentMetadataForLoad(tempbuf_address_untyped, allocation.size()); - AttachDereferenceableMetadataForLoad(tempbuf_address_untyped, - allocation.size()); + llvm::Value* tempbuf_address_untyped = tempbuf_address_base; + if (slice.offset() > 0) { + // Adjust the address to account for the slice offset. 
+ tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP( + tempbuf_address_base, ir_builder_.getInt64(slice.offset())); + } return ir_builder_.CreateBitCast(tempbuf_address_untyped, element_type->getPointerTo()); } @@ -1657,13 +1698,13 @@ void IrEmitter::EmitArrayFunctionCallInto( ir_builder_.getInt32(parameter_addresses.size()), tensorflow::strings::StrCat(name, "_parameter_addresses"), &ir_builder_); - for (int i = 0; i < parameter_addresses.size(); ++i) { + for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( parameter_addresses[i], ir_builder_.getInt8PtrTy(), llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, "_address_as_i8ptr"))); llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt32(i)}); + parameter_addresses_buffer, {ir_builder_.getInt64(i)}); ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); } @@ -1708,8 +1749,7 @@ StatusOr IrEmitter::EmitTargetAddressForOp( llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttr(llvm::AttributeSet::get( - retval->getContext(), retval->getArgNo() + 1, attr_builder)); + retval->addAttrs(attr_builder); } return ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); @@ -1717,9 +1757,9 @@ StatusOr IrEmitter::EmitTargetAddressForOp( // For other nodes, we need the temporary buffer allocated for this node to // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - assignment_.GetUniqueTopLevelAllocation(op)); - return EmitTempBufferPointer(allocation->index(), target_shape); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + return EmitTempBufferPointer(slice, target_shape); } Status IrEmitter::EmitTargetElementLoop( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 9df5b8b3d25..ebb7296a075 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -60,24 +60,30 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. - IrEmitter(const HloModule& hlo_module, const HloModuleConfig& module_config, - const BufferAssignment& assignment, llvm::Module* llvm_module, + IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR - // function. function_name_prefix is the desired name of the function. If the - // name is not unique among already emitted functions then a suffix is - // appended to make the name unique. is_entry_computation indicates that this - // is the entry computation of the HLO module. If 'instruction_order' is given - // then the HLO instructions are emitted in the given order. In this case, - // 'instruction_order' must be a topological sort of the set of nodes - // accessible from the root of the computation. + // function. + // + // function_name_prefix is the desired name of the function. If the name is + // not unique among already emitted functions then a suffix is appended to + // make the name unique. 
+ // + // is_entry_computation indicates that this is the entry computation of the + // HLO module. + // + // If 'instruction_order' is not NULL, then the HLO instructions are emitted + // in the given order. In this case, 'instruction_order' must be a + // topological sort of the set of nodes accessible from the root of the + // computation. StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_entry_computation, - std::vector* instruction_order = nullptr); + std::vector* instruction_order); protected: // @@ -114,6 +120,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloComputation* function) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; Status HandleSend(HloInstruction* send) override; + Status HandleSlice(HloInstruction* slice, + HloInstruction* /*operand*/) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* /*operand*/, + HloInstruction* /*start_indices*/) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* /*operand*/, + HloInstruction* /*update*/, + HloInstruction* /*start_indices*/) override; Status HandleRecv(HloInstruction* recv) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple( @@ -125,14 +140,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloComputation* function, tensorflow::gtl::ArraySlice static_operands) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) override; + Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) override; - Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) override; + Status HandleWhile(HloInstruction* xla_while) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -183,7 +195,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. - llvm::Value* EmitTempBufferPointer(BufferAllocation::Index temp_buf_index, + llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); // Emits a function into the current module. This can be used for @@ -290,7 +302,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::map emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map, + std::map, llvm::AllocaInst*> thread_local_buffers_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index ade7fa58a2b..bdddca99c2f 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -58,12 +57,11 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr> function_names, std::unordered_map hlo_to_profile_idx, std::unordered_map> aligned_constants) - : Executable(std::move(hlo_module), std::move(module_config)), + : Executable(std::move(hlo_module), ParallelCpuExecutable::ShapeSizeBytes), jit_(std::move(jit)), assignment_(std::move(assignment)), functions_names_(std::move(function_names)), @@ -97,75 +95,81 @@ static void MarkLiveAddressesInOutput( } } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. - se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations; - for (BufferAllocation::Index i = 0; i < buffer_count; ++i) { +Status ParallelCpuExecutable::AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers) { + CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + VLOG(3) << "Allocating " << assignment_->Allocations().size() + << " allocations for module " << module().name(); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { auto& allocation = assignment_->GetAllocation(i); + + VLOG(3) << allocation.ToString(); + if (allocation.is_entry_computation_parameter()) { - // Buffers do not need to be allocated for parameters. - device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + VLOG(3) << "allocation #" << i << " is a parameter"; continue; } if (allocation.is_thread_local()) { - // Buffers do not need to be allocated for thread-local temporaries. 
- device_allocations.push_back(se::DeviceMemoryBase(nullptr)); + VLOG(3) << "buffer #" << i << " is thread-local"; continue; } - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase device_allocation, - memory_allocator->Allocate(device_ordinal, allocation.size())); + int64 buffer_size = allocation.size(); + if (!(*buffers)[i].is_null()) { + VLOG(3) << "buffer #" << i + << " is in the preallocated result ShapedBuffer"; + } else { + TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); - if (VLOG_IS_ON(3)) { - VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size() - << " bytes for allocation #" << i << " [" - << device_allocation.opaque() << "]"; - std::vector parts; - for (const LogicalBuffer* buffer : allocation.assigned_buffers()) { - parts.push_back(buffer->ToString()); - } - VLOG(3) << " " << tensorflow::str_util::Join(parts, ", "); + VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" + << (*buffers)[i].opaque() << "]"; } - device_allocations.push_back(device_allocation); // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(), - allocation.size()); + TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); } - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelOutputAllocation()); - BufferAllocation::Index result_index = result_allocation->index(); - VLOG(3) << "result index: " << result_index; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + VLOG(3) << "result index: " << result_slice.index(); + return Status::OK(); +} + +Status ParallelCpuExecutable::ExecuteComputeFunctions( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { + std::vector argument_buffers(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape())); + argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); + } + return ExecuteComputeFunctions(run_options, argument_buffers, buffers, + hlo_execution_profile); +} + +Status ParallelCpuExecutable::ExecuteComputeFunctions( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice buffers, + HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to // profile. Allocate an additional profile counter for the entire // computation. std::vector profile_counters(hlo_to_profile_idx_.size() + 1); std::vector buffer_pointers; - for (auto& device_allocation : device_allocations) { + buffer_pointers.reserve(buffers.size()); + for (auto device_allocation : buffers) { buffer_pointers.push_back(device_allocation.opaque()); } @@ -188,8 +192,8 @@ ParallelCpuExecutable::ExecuteOnStream( std::list pending; // Call the function for each HLO instruction in topological order. 
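  // A post order lists every instruction after all of its operands, so by
  // the time an instruction is considered here, each of its operands has
  // either produced a result or already been scheduled.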
- for (auto* instruction : - module().entry_computation()->MakeInstructionPostOrder()) { + const HloComputation& entry_computation = *module().entry_computation(); + for (auto* instruction : entry_computation.MakeInstructionPostOrder()) { // Parameters and constants have no functions associated with them. Instead // just copy the existing buffer into the map containing instruction // results.. @@ -206,9 +210,9 @@ ParallelCpuExecutable::ExecuteOnStream( } } - auto* temps_array = buffer_pointers.data(); - auto* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); + void** temps_array = buffer_pointers.data(); + uint64* profile_counters_array = profile_counters.data(); + auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); tensorflow::mutex completion_queue_lock; tensorflow::condition_variable completion_queue_cv; std::deque completion_queue; @@ -227,11 +231,11 @@ ParallelCpuExecutable::ExecuteOnStream( continue; } - TF_ASSIGN_OR_RETURN( - const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelAllocation(instruction)); - - void* result_buffer = buffer_pointers[result_allocation->index()]; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array[result_slice.index()]) + + result_slice.offset(); // We cannot use a move-only RAII type like std::unique_ptr because the // list of operands is allocated on the main thread and transferred to the // worker via the lambda passed to enqueue_function. In order for the @@ -245,11 +249,12 @@ ParallelCpuExecutable::ExecuteOnStream( }); auto function = FindOrDie(functions, instruction); // The thread pool entry takes ownership of |operand_buffers|. 
+ const auto* exec_run_options = &run_options->run_options(); thread_pool->Schedule([instruction, &completion_queue, &completion_queue_lock, &completion_queue_cv, - result_buffer, run_options, operand_buffers, + result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array, function] { - function(result_buffer, run_options, operand_buffers, temps_array, + function(result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array); delete[] operand_buffers; // Push the completed HLO instruction on the queue, the main thread @@ -279,9 +284,11 @@ ParallelCpuExecutable::ExecuteOnStream( break; } } while (1); - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment_->GetUniqueTopLevelAllocation(instruction)); - void* result_buffer = buffer_pointers[result_allocation->index()]; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array[result_slice.index()]) + + result_slice.offset(); InsertOrDie(&results, instruction, result_buffer); --instructions_in_flight; } @@ -295,7 +302,8 @@ ParallelCpuExecutable::ExecuteOnStream( execution_profile_.set_compute_cycle_count(profile_counters.back()); } if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed(profile_counters.back()); + hlo_execution_profile->set_total_cycles_executed(entry_computation, + profile_counters.back()); for (auto hlo_prof_idx : hlo_to_profile_idx_) { const HloInstruction* hlo = hlo_prof_idx.first; @@ -304,6 +312,41 @@ ParallelCpuExecutable::ExecuteOnStream( } } + return Status::OK(); +} + +StatusOr +ParallelCpuExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); + if (!arguments.empty()) { + VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); + } + + // Allocate the temporary buffers required for the computation. + se::StreamExecutor* stream_executor = stream->parent(); + int device_ordinal = stream_executor->device_ordinal(); + int64 buffer_count = assignment_->Allocations().size(); + VLOG(3) << "temp buffer count: " << buffer_count; + + std::vector device_allocations( + assignment_->Allocations().size()); + TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, + stream->parent()->device_ordinal(), + &device_allocations)); + + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + const BufferAllocation::Index result_index = result_slice.index(); + VLOG(3) << "result index: " << result_index; + + TF_RETURN_IF_ERROR(ExecuteComputeFunctions( + run_options, arguments, device_allocations, hlo_execution_profile)); + // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. std::unordered_set marked_addresses; @@ -322,7 +365,7 @@ ParallelCpuExecutable::ExecuteOnStream( // live because they are referenced by the output of the computation // and are needed by the service. They will be deallocated by the // service. 
- for (auto i = 0; i < device_allocations.size(); ++i) { + for (size_t i = 0; i < device_allocations.size(); ++i) { auto alloc = device_allocations[i]; if (marked_addresses.count(alloc.opaque()) == 0 && alloc.opaque() != nullptr) { @@ -336,29 +379,92 @@ ParallelCpuExecutable::ExecuteOnStream( } StatusOr> ParallelCpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - return Unimplemented( - "ParallelCpuExecutable not supported yet with LocalService execution"); -} + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } -Status ParallelCpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { - return Unimplemented( - "preallocated result buffer not supported with ParallelCpuExecutable"); + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr result_buffer, + ShapedBuffer::MakeShapedBuffer( + result_shape(), stream->parent()->platform(), + stream->parent()->device_ordinal())); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, + hlo_execution_profile)); + + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_RETURN_IF_ERROR( + result_buffer->mutable_shape_index_to_buffer_entry() + ->ForEachMutableElementWithStatus( + [&buffers, &buffers_in_result, &result_buffer, this]( + const ShapeIndex& index, size_t* buffer_entry) { + if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) { + const std::vector& sources = + this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer + // such as a tuple element. + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( + src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *buffer_entry = result_buffer->mutable_buffers()->size(); + result_buffer->mutable_buffers()->push_back(buffer); + buffers_in_result[buffer_index] = true; + } + return Status::OK(); + })); + + // Free all buffers not in the result. 
+ for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; + if (!buffers_in_result[i] && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR(memory_allocator->Deallocate( + stream->parent()->device_ordinal(), &alloc)); + } + } + + return std::move(result_buffer); } StatusOr ParallelCpuExecutable::ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); } +const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { + return assignment_->points_to_analysis().GetPointsToSet( + module().entry_computation()->root_instruction()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 51ec9e5a741..6d5f790c394 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,7 +51,6 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr> instruction_functions, std::unordered_map hlo_to_profile_idx, std::unordered_map ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr> ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - Status ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; @@ -89,7 +81,44 @@ class ParallelCpuExecutable : public Executable { ir_module_string_ = ir_module_string; } + static int64 ShapeSizeBytes(const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + private: + // Allocate buffers required for execution and assign them to the elements of + // "buffers". "buffers" should be sized to the number of buffers in buffer + // assignment. Each vector element corresponds to a particular Index. If + // a vector element already contains a non-null DeviceMemoryBase, then no + // buffer is assigned for this element. 
+ Status AllocateBuffers( + DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector* buffers); + + // Calls the generated functions in 'function_names_', performing the + // computation with the given arguments using the supplied buffers. + Status ExecuteComputeFunctions( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunctions( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice + buffers, + HloExecutionProfile* hlo_execution_profile); + + // Returns the points-to set of the root instruction of the entry + // computation. Uses points-to analysis from buffer assignment. + const PointsToSet& GetRootPointsToSet() const; + // The JIT containing compiled modules. tensorflow::mutex jit_mutex_; std::unique_ptr jit_ GUARDED_BY(jit_mutex_); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 677080a8623..ee772f5c396 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -54,7 +54,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 384a978873d..6f1c97a2334 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -48,7 +48,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 8beb565ab3e..7c74912a7ab 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -112,13 +112,25 @@ llvm::SmallVector DetectMachineAttributes() { if (llvm::sys::getHostCPUFeatures(host_features)) { for (auto &feature : host_features) { if (feature.second) { - result.push_back(feature.first()); + llvm::StringRef feature_name = feature.first(); + // Skip avx512 for now, it isn't quite ready in LLVM. + if (feature_name.startswith("avx512")) { + continue; + } + result.push_back(feature_name); } } } return result; } +llvm::StringRef GetHostCpuName() { + auto cpu_name = llvm::sys::getHostCPUName(); + // Skip avx512 for now, it isn't quite ready in LLVM. 
+ cpu_name.consume_back("-avx512"); + return cpu_name; +} + CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { CompilerFunctor::VectorIntrinsics intrinsics; intrinsics.sse_intrinsics = (&runtime::ExpV4F32 != nullptr); @@ -136,13 +148,16 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, .setOptLevel(opt_level) .selectTarget( /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/llvm::sys::getHostCPUName(), + /*MCPU=*/GetHostCpuName(), /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, GetAvailableIntrinsics())) {} + opt_level, GetAvailableIntrinsics())) { + VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() + << " features: " << target_machine_->getTargetFeatureString().str(); +} SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 9d1c842e0fb..4d8653484a0 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -22,7 +22,7 @@ limitations under the License. #include "external/llvm/include/llvm/ADT/Triple.h" #include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "external/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "external/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" @@ -41,7 +41,7 @@ namespace cpu { // it's added to the JIT. 
class SimpleOrcJIT { public: - using ObjLayerT = llvm::orc::ObjectLinkingLayer<>; + using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer<>; using CompileLayerT = llvm::orc::IRCompileLayer; using ModuleHandleT = CompileLayerT::ModuleSetHandleT; diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 423ec29fdc9..2d9d9c7de62 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -96,8 +96,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, } // namespace xla -static xla::TransferManager* CreateCpuTransferManager() { - return new xla::CpuTransferManager(); +static std::unique_ptr CreateCpuTransferManager() { + return xla::MakeUnique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 1bef4e2b8c7..c13c86741cc 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -33,9 +33,6 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( StatusOr StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size, bool retry_on_failure) { - if (size == 0) { - return perftools::gputools::DeviceMemoryBase(nullptr, 0); - } TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); return stream_executor->AllocateArray(size); @@ -74,4 +71,8 @@ StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) { return stream_executors_[device_ordinal]; } +bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const { + return false; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 461cc818bff..391585a306d 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -51,6 +51,10 @@ class DeviceMemoryAllocator { // Return the platform that the allocator allocates memory on. const perftools::gputools::Platform* platform() const { return platform_; } + // Can we call Deallocate() as soon as a computation has been scheduled on + // a stream, or do we have to wait for the computation to complete first? + virtual bool AllowsAsynchronousDeallocation() const = 0; + protected: const perftools::gputools::Platform* platform_; }; @@ -69,6 +73,8 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { tensorflow::Status Deallocate( int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + bool AllowsAsynchronousDeallocation() const override; + private: StatusOr GetStreamExecutor( int device_ordinal); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index f9c9bbe2cdc..78a398f8efa 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -150,6 +151,10 @@ class DfsHloVisitor { virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); } + virtual Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, @@ -185,19 +190,16 @@ class DfsHloVisitor { virtual Status HandleTranspose(HloInstruction* transpose) = 0; virtual Status HandleParameter(HloInstruction* parameter) = 0; virtual Status HandleFusion(HloInstruction* fusion) = 0; - virtual Status HandleCall( - HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) = 0; + virtual Status HandleCall(HloInstruction* call) = 0; virtual Status HandleCustomCall( HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) = 0; virtual Status HandleSlice(HloInstruction* slice, HloInstruction* operand) = 0; - virtual Status HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice operands) = 0; + virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) = 0; virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, HloInstruction* update, @@ -215,9 +217,7 @@ class DfsHloVisitor { const Window& window, HloComputation* function) = 0; virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0; - virtual Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, - HloComputation* body) = 0; + virtual Status HandleWhile(HloInstruction* xla_while) = 0; virtual Status HandlePad(HloInstruction* pad) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 18cfaf83e1c..6557c3aa8e6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -121,9 +122,7 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { Status HandleFusion(HloInstruction* fusion) override { return DefaultAction(fusion); } - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice /*operands*/, - HloComputation* /*computation*/) override { + Status HandleCall(HloInstruction* call) override { return DefaultAction(call); } Status HandleCustomCall( @@ -136,10 +135,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { HloInstruction* /*operand*/) override { return DefaultAction(slice); } - Status HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice /*operands*/) override { - return DefaultAction(slice); + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* /*operand*/, + HloInstruction* /*start_indices*/) override { + return DefaultAction(dynamic_slice); } Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* /*operand*/, @@ -188,9 +187,7 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { Status HandleTranspose(HloInstruction* transpose) override { return DefaultAction(transpose); } - Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, - HloComputation* /*condition*/, - HloComputation* /*body*/) override { + Status HandleWhile(HloInstruction* xla_while) override { return DefaultAction(xla_while); } Status HandleSend(HloInstruction* send) override { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9dd276952cc..be4aadb6522 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -195,6 +195,19 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } + case HloOpcode::kIsFinite: { + // (x == x) && abs(x) != inf + auto type = operand_value->getType(); + auto equal_self = + ir_builder_->CreateFCmpOEQ(operand_value, operand_value); + auto abs_value = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); + auto infinity = llvm::ConstantFP::getInfinity(type); + auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); + auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); + return ir_builder_->CreateZExt( + result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); default: @@ -227,14 +240,18 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return ir_builder_->CreateFDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: return ir_builder_->CreateFRem(lhs_value, rhs_value); - - // The 'O' prefix on the LLVM ops means "ordered" compare where comparisons - // with NAN always return false. + // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered + // comparisons always return false when one of the operands is NaN, whereas + // unordered comparisons return true. + // + // We use ordered comparisons for everything except kNe, where we use an + // unordered comparison. 
This makes x != y equivalent to !(x == y), and + // matches C++'s semantics. case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, rhs_value, ir_builder_); case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value, + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kLt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, @@ -428,8 +445,8 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, int64 operand_no) const { - CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() - << " is not elementwise."; + CHECK(hlo.IsElementwise()) + << "HLO " << hlo.ToString() << " is not elementwise."; const Shape& operand_shape = hlo.operand(operand_no)->shape(); // If the operand is scalar, the source index is always {}. @@ -474,8 +491,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); auto random_value = [hlo]() { - CHECK(hlo->parent() != nullptr && hlo->parent()->parent() != nullptr); - const HloModule* module = hlo->parent()->parent(); + const HloModule* module = + hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent() + : hlo->parent()->parent(); return module->RandomNew64(); }; @@ -631,6 +649,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: case HloOpcode::kSign: @@ -724,11 +743,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_from_operand", operand_idx), + "concat_index_from_operand", operand_idx), ir_builder_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_not_from_operand", operand_idx), + "concat_index_not_from_operand", operand_idx), ir_builder_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), @@ -788,9 +807,20 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const IrArray::Index& index) -> StatusOr { IrArray::Index sliced_index(index.size()); for (int i = 0; i < index.size(); ++i) { - sliced_index[i] = ir_builder_->CreateAdd( - index[i], llvm::ConstantInt::get(index[i]->getType(), - hlo->slice_starts(i))); + int64 stride = hlo->slice_stride(i); + if (stride != 1) { + sliced_index[i] = ir_builder_->CreateAdd( + ir_builder_->CreateMul( + index[i], llvm::ConstantInt::get(index[i]->getType(), + stride)), + llvm::ConstantInt::get(index[i]->getType(), + hlo->slice_starts(i))); + } else { + sliced_index[i] = ir_builder_->CreateAdd( + index[i], + llvm::ConstantInt::get(index[i]->getType(), + hlo->slice_starts(i))); + } } return operand_to_generator.at(hlo->operand(0))(sliced_index); }; @@ -922,6 +952,68 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); + case HloOpcode::kPad: + return [=, &operand_to_generator]( + const IrArray::Index& padded_index) -> StatusOr { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < 
index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], + index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), + ir_builder_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); + }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str()); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 5b1a5a16d1f..3a9f8dc79ee 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -16,16 +16,47 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/regexp.h" namespace xla { +/* static */ void Executable::DumpExecutedHlo( + const HloModule& module, const string& label, + const HloExecutionProfile* profile) { + VLOG(2) << "module name = " << module.name(); + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + string generate_hlo_graph_regex; + if (!flags->xla_generate_hlo_graph.empty()) { + generate_hlo_graph_regex = flags->xla_generate_hlo_graph; + } else { + generate_hlo_graph_regex = + module.config().debug_options().xla_generate_hlo_graph(); + } + if (!generate_hlo_graph_regex.empty() && + RE2::PartialMatch(module.name(), generate_hlo_graph_regex)) { + hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, + flags->xla_hlo_graph_addresses, + flags->xla_hlo_graph_layout, profile); + } + if (!flags->xla_log_hlo_text.empty() && + RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) { + LOG(INFO) << "HLO for module " << module.name(); + LOG(INFO) << "Label: " << label; + XLA_LOG_LINES(2, module.ToString()); + } + if (!flags->xla_dump_hlo_text_to.empty()) { + hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to); + } +} + StatusOr> Executable::ExecuteOnStreams( - tensorflow::gtl::ArraySlice run_options, + tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< tensorflow::gtl::ArraySlice> arguments) { @@ -40,7 +71,7 @@ Executable::ExecuteOnStreams( std::vector return_values( run_options.size()); - for (int64 i = 0; i < run_options.size(); ++i) { + for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. @@ -68,13 +99,23 @@ Status Executable::DumpSessionModule() { *session_module_); } +// Removes illegal characters from filenames. +static void SanitizeFilename(string* name) { + for (char& c : *name) { + if (c == '/' || c == '\\' || c == '[' || c == ']') { + c = '_'; + } + } +} + /* static */ Status Executable::DumpToDirectory( - const string& directory_path, const string& filename, + const string& directory_path, string filename, const SessionModule& session_module) { tensorflow::Env* env = tensorflow::Env::Default(); if (!env->IsDirectory(directory_path).ok()) { TF_RETURN_IF_ERROR(env->CreateDir(directory_path)); } + SanitizeFilename(&filename); string file_path = tensorflow::io::JoinPath(directory_path, filename); return tensorflow::WriteBinaryProto(env, file_path, session_module); } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index ac478afabc2..291916cd9f7 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -19,16 +19,18 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" @@ -39,17 +41,18 @@ namespace xla { // A given platform's compiler will produce an Executable -- this is a uniform // interface that is used for launching compiled programs across platforms. -// -// TODO(leary) will need to extend this to support multiple streams/devices as -// we begin to compile single programs to run on multiple devices. class Executable { public: explicit Executable(std::unique_ptr hlo_module, - std::unique_ptr module_config) + HloCostAnalysis::ShapeSizeFunction shape_size_function) : hlo_module_(std::move(hlo_module)), - module_config_(std::move(module_config)) {} + shape_size_function_(std::move(shape_size_function)) {} virtual ~Executable() {} + // Dumps the executed HLO according to service-associated flags. + static void DumpExecutedHlo(const HloModule& module, const string& label, + const HloExecutionProfile* profile); + // Enqueues the compilation result on the provided stream, passing the given // arguments. This call is blocking and returns after the execution is done. // @@ -59,7 +62,7 @@ class Executable { // Returns the device memory region that a successful execution would // populate. virtual StatusOr ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) = 0; @@ -67,22 +70,14 @@ class Executable { // Overload of ExecuteOnStream which returns and takes arguments as // ShapedBuffers. Used for LocalService execution. virtual StatusOr> ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) = 0; - // Overload of which writes the result into a pre-allocated buffer - // (result_buffer). - virtual Status ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, - HloExecutionProfile* hlo_execution_profile) = 0; - // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. virtual StatusOr ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) = 0; @@ -92,11 +87,22 @@ class Executable { // returned vector. 
virtual StatusOr> ExecuteOnStreams( - tensorflow::gtl::ArraySlice run_options, + tensorflow::gtl::ArraySlice + run_options, tensorflow::gtl::ArraySlice< tensorflow::gtl::ArraySlice> arguments); + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a + // timer for the execution, sets up HLO profiling if enabled, and fills in the + // given ExecutionProfile if non-null. The ExecuteOnStream overloads have + // different argument types and return types, so this method is templated on + // argument type and return type of the execute function. + template + StatusOr ExecuteOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + const ArgT& arguments); + // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. ExecutionProfile execution_profile() const { @@ -108,15 +114,14 @@ class Executable { // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. bool hlo_profiling_enabled() const { - return module_config_->hlo_profiling_enabled(); + return hlo_module_->config().hlo_profiling_enabled(); } const HloModule& module() const { return *hlo_module_; } - const HloModuleConfig& module_config() const { return *module_config_; } + const bool has_module() const { return hlo_module_ != nullptr; } - // Returns whether this executable has an associated HloModuleConfig. - bool has_module_config() const { return module_config_ != nullptr; } + const HloModuleConfig& module_config() const { return hlo_module_->config(); } // Returns the versioned computation handle of the computation computed by // this executable. @@ -127,7 +132,7 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. const Shape& result_shape() const { - return module_config_->entry_computation_layout().result_shape(); + return hlo_module_->config().entry_computation_layout().result_shape(); } // Dumping helpers. @@ -139,10 +144,14 @@ class Executable { Status DumpSessionModule(); // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, - const string& filename, + static Status DumpToDirectory(const string& directory_path, string filename, const SessionModule& session_module); + // Return a reference to a function that computes the size of a given Shape. + const HloCostAnalysis::ShapeSizeFunction& shape_size_function() const { + return shape_size_function_; + } + protected: mutable tensorflow::mutex mutex_; @@ -154,9 +163,10 @@ class Executable { // around. std::unique_ptr hlo_module_; - // The configuration used to build this executable (parameter layouts, result - // layout, profiling enabled, etc). - std::unique_ptr module_config_; + // Function to compute the size of a given Shape, in bytes. This is + // provided to the Executable when it is constructed, and used to produce + // data for profiling the execution. + HloCostAnalysis::ShapeSizeFunction shape_size_function_; // SessionModule this was compiled from. Null if not dumping executions. 
std::unique_ptr session_module_; @@ -166,6 +176,76 @@ class Executable { int64 execution_count_ = 0; }; +template +StatusOr Executable::ExecuteOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + const ArgT& arguments) { + perftools::gputools::Stream* stream = run_options->stream(); + std::unique_ptr timer; + if (profile != nullptr) { + timer.reset(new perftools::gputools::Timer(stream->parent())); + stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); + } + + VLOG(1) << "enqueueing executable on stream..."; + // If the profiling flag isn't enabled, we pass nullptr as the profile to + // indicate profiling is not requested. + HloExecutionProfile hlo_execution_profile; + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + HloExecutionProfile* profile_ptr = + flags->xla_hlo_profile && hlo_profiling_enabled() ? &hlo_execution_profile + : nullptr; + + auto return_value = ExecuteOnStream(run_options, arguments, profile_ptr); + + if (profile != nullptr) { + VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; + stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); + VLOG(1) << "done with block-host-until-done"; + + // Merge in run-time profile information from execution_profile. + profile->MergeFrom(execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + + if (profile_ptr != nullptr) { + std::unordered_set profiled_computations = + profile_ptr->profiled_computations(); + // To ensure we have print the profiles in a stable order, iterate over the + // computations in post order. + std::list all_computations = + module().MakeComputationPostOrder(); + for (xla::HloComputation* computation : all_computations) { + if (profiled_computations.count(computation) > 0) { + string profile_string = profile_ptr->ToString( + *computation, stream->parent()->GetDeviceDescription(), + shape_size_function_); + if (!profile_string.empty()) { + XLA_LOG_LINES(tensorflow::INFO, profile_string); + } + } + } + DumpExecutedHlo(module(), "Service::Execute", profile_ptr); + } + + return return_value; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index cf1870580c4..c225e62e3e1 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -24,25 +24,19 @@ limitations under the License. 
namespace xla { -AsyncExecution::AsyncExecution( - Backend* backend, - std::vector> streams, - const ExecutionProfile& profile, GlobalDataHandle result) +AsyncExecution::AsyncExecution(Backend* backend, + std::vector streams, + const ExecutionProfile& profile, + GlobalDataHandle result) : backend_(CHECK_NOTNULL(backend)), streams_(std::move(streams)), profile_(profile), - result_(result) { + result_(std::move(result)) { for (const auto& stream : streams_) { CHECK(stream != nullptr); } } -AsyncExecution::~AsyncExecution() { - for (auto& stream : streams_) { - backend_->ReleaseStream(std::move(stream)); - } -} - tensorflow::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { if (!stream->BlockHostUntilDone()) { @@ -55,8 +49,7 @@ tensorflow::Status AsyncExecution::BlockUntilDone() const { ExecutionTracker::ExecutionTracker() : next_handle_(1) {} ExecutionHandle ExecutionTracker::Register( - Backend* backend, - std::vector> streams, + Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result) { tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 99a5bb5ad99..5b6bddf9f16 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -39,12 +40,9 @@ namespace xla { // the stream when destructed. class AsyncExecution { public: - AsyncExecution( - Backend* backend, - std::vector> streams, - const ExecutionProfile& profile, GlobalDataHandle result); + AsyncExecution(Backend* backend, std::vector streams, + const ExecutionProfile& profile, GlobalDataHandle result); - ~AsyncExecution(); tensorflow::Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -56,7 +54,7 @@ class AsyncExecution { Backend* backend_; // Stream on which the execution is launched. - std::vector> streams_; + std::vector streams_; // Profile object of the execution to be returned to the user. ExecutionProfile profile_; @@ -73,10 +71,10 @@ class ExecutionTracker { // Registers an execution with its backend, streams, and data handle to the // execution result. Returns a handle for the registered execution. - ExecutionHandle Register( - Backend* backend, - std::vector> stream, - const ExecutionProfile& profile, GlobalDataHandle data); + ExecutionHandle Register(Backend* backend, + std::vector stream, + const ExecutionProfile& profile, + GlobalDataHandle data); // Unregisters the execution for the given handle. tensorflow::Status Unregister(const ExecutionHandle& handle); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc new file mode 100644 index 00000000000..297a4f7599f --- /dev/null +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Helper to replace the called computation at a while- or call-instruction. +void ReplaceCalledComputation(HloInstruction* instruction, + HloComputation* computation, + HloComputation* new_computation) { + switch (instruction->opcode()) { + case HloOpcode::kWhile: { + if (computation == instruction->while_condition()) { + instruction->set_while_condition(new_computation); + } else { + CHECK_EQ(computation, instruction->while_body()); + instruction->set_while_body(new_computation); + } + break; + } + case HloOpcode::kCall: { + CHECK_EQ(instruction->to_apply(), computation); + instruction->set_to_apply(new_computation); + break; + } + default: + LOG(FATAL) << "unexpected opcode: " + << HloOpcodeString(instruction->opcode()); + } +} + +// Flatten a single call graph node. Expects to visit nodes in postorder. +Status FlattenNode(const CallGraphNode& node) { + HloComputation* computation = node.computation(); + HloModule* module = computation->parent(); + // Clone callee for all call-sites except the first one. + for (int i = 0; i < node.caller_callsites().size(); ++i) { + CallSite call_site = node.caller_callsites()[i]; + // Only consider sequential call contexts. + if (call_site.context() == CallContext::kParallel) { + continue; + } + CHECK_EQ(call_site.context(), CallContext::kSequential); + + // Skip first element if this computation is only called from a sequential + // context. + if (node.context() != CallContext::kBoth && i == 0) { + continue; + } + + // Clone computation for the remaining sequential context call sites. + HloComputation* clone = + module->AddEmbeddedComputation(computation->Clone()); + ReplaceCalledComputation(call_site.instruction(), computation, clone); + // Clone the sub-tree of all computations called from this node. 
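+    // The clone is walked with an explicit worklist: every freshly cloned
+    // computation is pushed, and each of its sequentially called computations
+    // is cloned in turn, so the cloned sub-tree shares nothing with the
+    // original tree.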
+    std::vector<HloComputation*> worklist;
+    worklist.push_back(clone);
+    while (!worklist.empty()) {
+      auto current = worklist.back();
+      worklist.pop_back();
+      for (auto& instruction : current->instructions()) {
+        if (GetInstructionCallContext(instruction.get()) !=
+            CallContext::kSequential) {
+          continue;
+        }
+        for (auto callee : instruction->called_computations()) {
+          HloComputation* callee_clone =
+              module->AddEmbeddedComputation(callee->Clone());
+          ReplaceCalledComputation(instruction.get(), callee, callee_clone);
+          worklist.push_back(callee_clone);
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+StatusOr<bool> FlattenCallGraph::Run(HloModule* module) {
+  XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
+
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode));
+
+  XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
+  return true;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
new file mode 100644
index 00000000000..d3efab36149
--- /dev/null
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Flatten the call graph for an HLO module into a tree.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Flattening associates each call site with a unique computation (for
+// sequential calling contexts). This simplifies buffer assignment and
+// points-to analysis (see b/36865746 for details).
+class FlattenCallGraph : public HloPassInterface {
+ public:
+  tensorflow::StringPiece name() const override { return "flatten-call-graph"; }
+
+  // Duplicates computations called from multiple call- or while-nodes to
+  // flatten the call graph.
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
new file mode 100644
index 00000000000..bb4712c86f6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -0,0 +1,231 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class FlattenCallGraphTest : public HloTestBase {
+ protected:
+  // Build and return a trivial computation taking and returning a scalar.
+  std::unique_ptr<HloComputation> MakeScalarComputation() {
+    HloComputation::Builder builder(TestName() + ".ScalarComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    builder.AddInstruction(
+        HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0));
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and maps (kMap) the
+  // given computation to the value 'callsites' number of times.
+  std::unique_ptr<HloComputation> MakeMappingComputation(
+      HloComputation* map_computation, int64 callsites) {
+    HloComputation::Builder builder(TestName() + ".MappingComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* last_value = param0;
+    for (int64 i = 0; i < callsites; ++i) {
+      last_value = builder.AddInstruction(HloInstruction::CreateMap(
+          kScalarShape, {last_value}, map_computation));
+    }
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and calls (kCall) the
+  // given computation with value 'callsites' number of times.
+  std::unique_ptr<HloComputation> MakeCallingComputation(
+      HloComputation* callee_computation, int64 callsites,
+      const string& suffix = ".CallingComputation") {
+    HloComputation::Builder builder(TestName() + suffix);
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* last_value = param0;
+    for (int64 i = 0; i < callsites; ++i) {
+      last_value = builder.AddInstruction(HloInstruction::CreateCall(
+          kScalarShape, {last_value}, callee_computation));
+    }
+    return builder.Build();
+  }
+
+  // Build and return a computation which takes a scalar and returns a PRED
+  // value.
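+  // (The PRED result lets the tests below use this computation as a
+  // while-loop condition.)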
+  std::unique_ptr<HloComputation> MakeConditionComputation() {
+    HloComputation::Builder builder(TestName() + ".ConditionComputation");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* zero = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+    builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
+    return builder.Build();
+  }
+
+  StatusOr<bool> RunFlattenCallGraph(HloModule* module) {
+    FlattenCallGraph flatten;
+    TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module));
+    return result;
+  }
+
+  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(FlattenCallGraphTest, ComplexGraph) {
+  // Test a call graph of a module with several computations called in various
+  // contexts. The call graph looks like:
+  //
+  //      entry
+  //      /  |
+  //     a   |
+  //   / | \ |
+  //  b  |  cond
+  //   \ |
+  //    c
+  //
+  // Calls are made via kCall, kWhile, and kMap instructions.
+  auto module = CreateNewModule();
+  HloComputation* cond_computation =
+      module->AddEmbeddedComputation(MakeConditionComputation());
+  HloComputation* c_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+  HloComputation* b_computation = module->AddEmbeddedComputation(
+      MakeMappingComputation(c_computation, /*callsites=*/1));
+
+  HloComputation* a_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".a");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    HloInstruction* call = builder.AddInstruction(
+        HloInstruction::CreateCall(kScalarShape, {param0}, c_computation));
+    builder.AddInstruction(HloInstruction::CreateWhile(
+        kScalarShape, cond_computation, b_computation, call));
+    a_computation = module->AddEmbeddedComputation(builder.Build());
+  }
+
+  HloComputation* entry_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".entry");
+    HloInstruction* param0 = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, kScalarShape, "param0"));
+    builder.AddInstruction(HloInstruction::CreateWhile(
+        kScalarShape, cond_computation, a_computation, param0));
+    entry_computation = module->AddEntryComputation(builder.Build());
+  }
+
+  {
+    TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
+    EXPECT_TRUE(result);
+    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+    const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
+    EXPECT_EQ(1, c_node.caller_callsites().size());
+  }
+}
+
+// Test corner case of a computation used as a body and a loop condition.
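+// Flattening must clone the shared computation so that the while condition
+// and the while body end up as two distinct computations.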
+TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
+  auto module = CreateNewModule();
+  HloComputation* cond_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".cond");
+    HloInstruction* param0 =
+        builder.AddInstruction(HloInstruction::CreateParameter(
+            0, ShapeUtil::MakeShape(PRED, {}), "param0"));
+    HloInstruction* false_constant = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+                                     HloOpcode::kEq, param0, false_constant));
+    cond_computation = module->AddEmbeddedComputation(builder.Build());
+  }
+
+  HloComputation* entry_computation;
+  {
+    HloComputation::Builder builder(TestName() + ".entry");
+    HloInstruction* false_constant = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+    builder.AddInstruction(HloInstruction::CreateWhile(
+        ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
+        false_constant));
+    entry_computation = module->AddEntryComputation(builder.Build());
+  }
+
+  {
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
+    EXPECT_EQ(2, cond_node.caller_callsites().size());
+  }
+
+  {
+    TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
+    EXPECT_TRUE(result);
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
+    EXPECT_EQ(1, cond_node.caller_callsites().size());
+  }
+}
+
+// Test flattening of nested calling computations.
+//
+//  Entry
+//   / \
+//   \ /
+//    B
+//   / \
+//   \ /
+//    C
+//
+TEST_F(FlattenCallGraphTest, FlattenCalls) {
+  auto module = CreateNewModule();
+  HloComputation* c_computation =
+      module->AddEmbeddedComputation(MakeScalarComputation());
+
+  HloComputation* b_computation = module->AddEmbeddedComputation(
+      MakeCallingComputation(c_computation, /*callsites=*/2, ".B"));
+
+  module->AddEntryComputation(
+      MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
+
+  TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(module.get()));
+  EXPECT_TRUE(result);
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  EXPECT_EQ(7, module->computations().size());
+
+  const CallGraphNode& c_node = call_graph->GetNode(c_computation);
+  EXPECT_EQ(1, c_node.caller_callsites().size());
+
+  const CallGraphNode& b_node = call_graph->GetNode(b_computation);
+  EXPECT_EQ(1, b_node.caller_callsites().size());
+}
+
+}  // namespace
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 8f39ba8b1d2..eb8b93330fb 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -118,10 +118,10 @@ GenericTransferManager::ShallowCopyTupleFromDevice(
  // Create a DeviceMemoryBase from each void* pointer.
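  // (A null element pointer is legal only for a tuple element with zero
  // elements; any other null trips the precondition check below.)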
  std::vector<se::DeviceMemoryBase> destination;
-  for (int i = 0; i < element_pointers.size(); ++i) {
+  for (size_t i = 0; i < element_pointers.size(); ++i) {
    if (element_pointers[i] == nullptr &&
        !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
-      return FailedPrecondition("tuple contains nullptr at element %d", i);
+      return FailedPrecondition("tuple contains nullptr at element %lu", i);
    }
    int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i),
                                              /*pointer_size=*/sizeof(void*));
@@ -162,6 +162,12 @@ Status GenericTransferManager::TransferLiteralToInfeed(
  return Unimplemented("Infeed is not supported on GPU (b/30467474)");
}

+Status GenericTransferManager::TransferLiteralFromOutfeed(
+    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+    Literal* literal) {
+  return Unimplemented("Outfeed is not supported on CPU/GPU (b/30467474)");
+}
+
Status GenericTransferManager::ResetDevices(
    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
        executors) {
@@ -174,14 +180,3 @@ int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) {
}

}  // namespace xla
-
-static xla::TransferManager* CreateGenericTransferManager() {
-  return new xla::GenericTransferManager(se::cuda::kCudaPlatformId);
-}
-
-static bool InitModule() {
-  xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId,
-                                                CreateGenericTransferManager);
-  return true;
-}
-static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 06819d65c70..2fbdb94f06f 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,6 +55,10 @@ class GenericTransferManager : public TransferManager {
  Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
                                 const Literal& literal) override;

+  Status TransferLiteralFromOutfeed(
+      perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+      Literal* literal) override;
+
  Status ResetDevices(
      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
          executors) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index c1abf2237bd..86986934117 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -76,6 +76,7 @@ cc_library(

cc_test(
    name = "stream_assignment_test",
+    size = "small",
    srcs = [
        "stream_assignment_test.cc",
    ],
@@ -86,7 +87,6 @@
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/core:lib",
-        "//tensorflow/core:test_main",
    ],
)

@@ -96,7 +96,6 @@ cc_library(
    srcs = ["hlo_to_ir_bindings.cc"],
    hdrs = ["hlo_to_ir_bindings.h"],
    deps = [
        ":ir_emission_utils",
-        ":temp_buffer_offsets",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:hlo",
@@ -127,7 +126,6 @@
        ":ir_emission_utils",
        ":parallel_loop_emitter",
        ":partition_assignment",
-        ":temp_buffer_offsets",
        ":while_transformer",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@@ -197,23 +195,11 @@
    ],
)

-cc_library(
-    name = "temp_buffer_offsets",
-    srcs = ["temp_buffer_offsets.cc"],
-    hdrs = ["temp_buffer_offsets.h"],
-    deps = [
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla/service:buffer_assignment",
-        "//tensorflow/core:lib",
-    ],
-)
-
cc_library(
    name = "buffer_allocations",
    srcs = ["buffer_allocations.cc"],
    hdrs = ["buffer_allocations.h"],
    deps = [
-
":temp_buffer_offsets", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -233,6 +219,7 @@ cc_library( "for_thunk.cc", "gemm_thunk.cc", "gpu_executable.cc", + "infeed_thunk.cc", "kernel_thunk.cc", "sequential_thunk.cc", "thunk_schedule.cc", @@ -245,6 +232,7 @@ cc_library( "for_thunk.h", "gemm_thunk.h", "gpu_executable.h", + "infeed_thunk.h", "kernel_thunk.h", "sequential_thunk.h", "thunk.h", @@ -254,9 +242,9 @@ cc_library( ], deps = [ ":buffer_allocations", + ":infeed_manager", ":partition_assignment", ":stream_assignment", - ":temp_buffer_offsets", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -271,13 +259,14 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform/default/build_config:cublas_plugin", + "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], ) @@ -316,6 +305,7 @@ cc_library( cc_test( name = "convolution_folding_test", + size = "small", srcs = ["convolution_folding_test.cc"], deps = [ ":convolution_folding", @@ -324,7 +314,6 @@ cc_test( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", - "//tensorflow/core:test_main", ], ) @@ -342,12 +331,11 @@ cc_library( cc_test( name = "instruction_fusion_test", + size = "small", srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test", - "//tensorflow/core:test_main", ], ) @@ -370,31 +358,26 @@ cc_library( srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], deps = [ + ":instruction_fusion", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/core:lib", ], ) cc_test( name = "fusion_merger_test", + size = "small", srcs = ["fusion_merger_test.cc"], deps = [ ":fusion_merger", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -430,7 +413,7 @@ cc_library( ":pad_insertion", ":partition_assignment", ":stream_assignment", - ":temp_buffer_offsets", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -441,13 +424,17 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", + 
"//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_proto_util", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", @@ -461,6 +448,18 @@ cc_library( alwayslink = True, # Contains compiler registration ) +cc_library( + name = "infeed_manager", + srcs = ["infeed_manager.cc"], + hdrs = ["infeed_manager.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "layout_assignment", srcs = ["layout_assignment.cc"], @@ -479,6 +478,7 @@ cc_library( cc_test( name = "layout_assignment_test", + size = "small", srcs = ["layout_assignment_test.cc"], deps = [ ":layout_assignment", @@ -488,7 +488,6 @@ cc_test( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -508,6 +507,7 @@ cc_library( cc_test( name = "hlo_schedule_test", + size = "small", srcs = [ "hlo_schedule_test.cc", ], @@ -518,7 +518,6 @@ cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", ], ) @@ -539,19 +538,15 @@ cc_library( cc_test( name = "while_transformer_test", + size = "small", srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", ":while_transformer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index a9975de3f17..9fdf717b5d4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -38,28 +38,12 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, - const TempBufferOffsets& temp_buffer_offsets, int device_ordinal, + const BufferAssignment& buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - se::DeviceMemoryBase temp_buffer_base; - if (temp_buffer_offsets.TotalSizeInBytes() > 0) { - TF_ASSIGN_OR_RETURN( - temp_buffer_base, - memory_allocator->Allocate(device_ordinal, - temp_buffer_offsets.TotalSizeInBytes())); - if (temp_buffer_base == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s bytes for temporary buffers.", - 
-          tensorflow::strings::HumanReadableNumBytes(
-              temp_buffer_offsets.TotalSizeInBytes())
-              .c_str());
-    }
-  }
-  auto buffer_allocations = WrapUnique(new BufferAllocations(
-      buffer_assignment.Allocations().size(), temp_buffer_base, device_ordinal,
-      memory_allocator));
+  const int64 num_buffers = buffer_assignment.Allocations().size();
+  auto buffer_allocations = WrapUnique(
+      new BufferAllocations(num_buffers, device_ordinal, memory_allocator));

-  int64 num_buffers = buffer_assignment.Allocations().size();
  for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
    // If buffer #i's address is already registered (e.g. external arguments or
    // result buffers), use that registered buffer.
@@ -68,13 +52,13 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
      continue;
    }

+    // Allocate each allocation that might escape, or is the temp buffer.
+    bool seen_temp_buffer = false;
    const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
-    if (allocation.maybe_live_out()) {
-      auto buffer_size = allocation.size();
+    if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) {
+      const int64 buffer_size = allocation.size();
      se::DeviceMemoryBase buffer_address;
      if (buffer_size > 0) {
-        // If the buffer escapes, we need to allocate it separately instead of
-        // merging it into the memory block for temporary buffers.
        TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate(
                                                device_ordinal, buffer_size));
        if (buffer_address == nullptr) {
@@ -85,13 +69,14 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
        }
      }
      buffer_allocations->SetBuffer(i, buffer_address);
-    } else if (allocation.IsPreallocatedTempBuffer()) {
-      se::DeviceMemoryBase temp_buffer_address(
-          /*opaque=*/static_cast<char*>(
-              buffer_allocations->GetTempBufferBase().opaque()) +
-              temp_buffer_offsets.GetOffset(i),
-          /*size=*/allocation.size());
-      buffer_allocations->SetBuffer(i, temp_buffer_address);
+      if (allocation.IsPreallocatedTempBuffer()) {
+        if (seen_temp_buffer) {
+          LOG(FATAL) << "Multiple temporary buffers detected. BufferAssigner "
+                     << "must guarantee at most one temporary buffer.";
+        }
+        seen_temp_buffer = true;
+        buffer_allocations->temp_buffer_base_ = buffer_address;
+      }
    }
  }

@@ -102,22 +87,19 @@ tensorflow::Status BufferAllocations::TearDown(
    const std::set<se::DeviceMemoryBase>& live_addresses,
    const BufferAssignment& buffer_assignment) {
  // Deallocate temporary buffers.
-  for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) {
+  const int64 num_buffers = buffer_assignment.Allocations().size();
+  for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
    const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
    se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
-    if (allocation.maybe_live_out() && !live_addresses.count(buffer_address)) {
-      // Deallocate buffers that marked "maybe_live_out" but is not actually
-      // live out.
+    // Deallocate buffers that are marked "maybe_live_out" but aren't actually
+    // live out, and temp buffers.
+    if ((allocation.maybe_live_out() &&
+         !live_addresses.count(buffer_address)) ||
+        allocation.IsPreallocatedTempBuffer()) {
      TF_RETURN_IF_ERROR(
          memory_allocator_->Deallocate(device_ordinal_, &buffer_address));
    }
  }
-
-  // Deallocate the memory block for temporary buffers.
-  if (temp_buffer_base_ != nullptr) {
-    TF_RETURN_IF_ERROR(
-        memory_allocator_->Deallocate(device_ordinal_, &temp_buffer_base_));
-  }
  return tensorflow::Status::OK();
}

@@ -128,6 +110,16 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress(
  return buffers_[buffer_index];
}

+se::DeviceMemoryBase BufferAllocations::GetDeviceAddress(
+    const BufferAllocation::Slice& buffer_slice) const {
+  se::DeviceMemoryBase base = GetDeviceAddress(buffer_slice.index());
+  CHECK_LE(buffer_slice.offset(), base.size());
+  CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size());
+  return se::DeviceMemoryBase(
+      static_cast<char*>(base.opaque()) + buffer_slice.offset(),
+      buffer_slice.size(), /*is_sub_buffer=*/true);
+}
+
void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index,
                                  se::DeviceMemoryBase buffer) {
  CHECK_GE(buffer_index, 0);
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index a0cd6cac016..ea7f0eb3745 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -22,7 +22,6 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
-#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -49,8 +48,7 @@ class BufferAllocations {
  // `device_ordinal` is the number of the device this function allocates
  // memory on.
  StatusOr<std::unique_ptr<BufferAllocations>> Build(
-      const BufferAssignment& buffer_assignment,
-      const TempBufferOffsets& temp_buffer_offsets, int device_ordinal,
+      const BufferAssignment& buffer_assignment, int device_ordinal,
      DeviceMemoryAllocator* memory_allocator);

 private:
@@ -70,6 +68,11 @@ class BufferAllocations {
  perftools::gputools::DeviceMemoryBase GetDeviceAddress(
      BufferAllocation::Index buffer_index) const;

+  // Same as above, but also adjusts the returned address for the offset and
+  // size contained in the given slice.
+  perftools::gputools::DeviceMemoryBase GetDeviceAddress(
+      const BufferAllocation::Slice& buffer_slice) const;
+
  perftools::gputools::DeviceMemoryBase GetTempBufferBase() const {
    return temp_buffer_base_;
  }
@@ -81,12 +84,9 @@ class BufferAllocations {
                              const BufferAssignment& buffer_assignment);

 private:
-  BufferAllocations(BufferAllocation::Index buffer_count,
-                    perftools::gputools::DeviceMemoryBase temp_buffer_base,
-                    int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+  BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal,
+                    DeviceMemoryAllocator* memory_allocator)
      : buffers_(buffer_count),
-        temp_buffer_base_(
-            perftools::gputools::DeviceMemory<uint8>(temp_buffer_base)),
        device_ordinal_(device_ordinal),
        memory_allocator_(memory_allocator) {}

@@ -100,7 +100,7 @@ class BufferAllocations {
  std::vector<perftools::gputools::DeviceMemoryBase> buffers_;

  // The base address of the memory block that contains all temporary buffers.
-  perftools::gputools::DeviceMemory<uint8> temp_buffer_base_;
+  perftools::gputools::DeviceMemoryBase temp_buffer_base_;

  int device_ordinal_;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
index b407a01f0af..16febea14de 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
@@ -106,7 +106,7 @@ MatchBackwardFilter(HloInstruction* conv) {
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
-  for (int i = 0; i < 2; ++i) {
+  for (int i = 0; i < spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of the
    // forward convolution.
@@ -185,7 +185,7 @@ MatchBackwardFilter(HloInstruction* conv) {
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_batch_dimension(feature_dim);
  backward_conv_dnums.set_feature_dimension(batch_dim);
-  for (int i = 0; i < 2; ++i) {
+  for (int i = 0; i < spatial_dims.size(); ++i) {
    backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
@@ -201,7 +201,7 @@ MatchBackwardFilter(HloInstruction* conv) {
      PositionInContainer(transpose->dimensions(), batch_dim));
  backward_conv_dnums.set_kernel_output_feature_dimension(
      PositionInContainer(transpose->dimensions(), feature_dim));
-  for (int i = 0; i < 2; ++i) {
+  for (int i = 0; i < spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(
        PositionInContainer(transpose->dimensions(), spatial_dims[i]));
  }
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
index 83922cbe14a..ba9c70ded36 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
@@ -97,10 +97,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) {
      activations, gradients, conv_window,
      tf_default_dnums_for_backward_filter_));

-  HloModule module(TestName());
+  auto module = CreateNewModule();
  HloComputation* entry_computation =
-      module.AddEntryComputation(builder.Build());
-  EXPECT_TRUE(FoldConvolution(&module));
+      module->AddEntryComputation(builder.Build());
+  EXPECT_TRUE(FoldConvolution(module.get()));
  EXPECT_EQ(HloOpcode::kFusion,
            entry_computation->root_instruction()->opcode());
  EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
@@ -126,9 +126,9 @@ TEST_F(ConvolutionFoldingTest,
      activations, gradients, conv_window,
      tf_default_dnums_for_backward_filter_));

-  HloModule module(TestName());
-  module.AddEntryComputation(builder.Build());
-  EXPECT_FALSE(FoldConvolution(&module));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+  EXPECT_FALSE(FoldConvolution(module.get()));
}

// Extracted from block35 training.
@@ -155,10 +155,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == @@ -189,10 +189,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == @@ -222,10 +222,10 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == @@ -269,10 +269,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) .ValueOrDie())); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == @@ -313,10 +313,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { /*lhs=*/output, /*rhs=*/kernel, conv_window, tf_default_dnums_for_backward_input_)); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == @@ -346,9 +346,9 @@ TEST_F(ConvolutionFoldingTest, /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, tf_default_dnums_for_backward_input_)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(&module)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + 
EXPECT_FALSE(FoldConvolution(module.get())); } // Extracted from Inception V3 training. @@ -394,10 +394,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { tf_default_dnums_for_backward_input_) .ValueOrDie())); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); EXPECT_EQ(HloOpcode::kFusion, entry_computation->root_instruction()->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == @@ -441,9 +441,9 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { tf_default_dnums_for_backward_input_) .ValueOrDie())); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(&module)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(module.get())); } // Extracted from //learning/brain/google/xla/benchmarks/resnet.py @@ -490,10 +490,10 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_) .ValueOrDie())); - HloModule module(TestName()); + auto module = CreateNewModule(); const HloComputation* entry_computation = - module.AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(&module)); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(FoldConvolution(module.get())); const HloInstruction* backward_conv = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == @@ -543,10 +543,14 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_) .ValueOrDie())); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(&module)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_FALSE(FoldConvolution(module.get())); } } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 30a92ab3130..9a0b14eb733 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,7 +29,6 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -using Index = BufferAllocation::Index; using se::dnn::BatchDescriptor; using se::dnn::ConvolutionDescriptor; using se::dnn::DataLayout; @@ -92,12 +91,15 @@ string ConvolutionKindToString( case ConvolutionThunk::ConvolutionKind::kBackwardInput: return "backward_input"; } + return "unknown convolution kind"; } ConvolutionThunk::ConvolutionThunk( - ConvolutionKind convolution_kind, Index input_buffer, Index filter_buffer, - Index output_buffer, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, + ConvolutionKind convolution_kind, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& filter_buffer, + const BufferAllocation::Slice& output_buffer, const Shape& input_shape, + const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) : Thunk(Kind::kConvolution, 
hlo), convolution_kind_(convolution_kind), @@ -119,50 +121,78 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; + const int num_dimensions = window_.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(2, window_.dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); for (const WindowDimension& dim : window_.dimensions()) { CHECK_EQ(dim.padding_low(), dim.padding_high()); } - const WindowDimension& height = window_.dimensions(0); - const WindowDimension& width = window_.dimensions(1); // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls - // when we switch to cuDNN v5. - BatchDescriptor input_descriptor; + BatchDescriptor input_descriptor(effective_num_dimensions); input_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_height(input_shape_.dimensions(dim_nums_.spatial_dimensions(0))) - .set_width(input_shape_.dimensions(dim_nums_.spatial_dimensions(1))) .set_feature_map_count( input_shape_.dimensions(dim_nums_.feature_dimension())) .set_count(input_shape_.dimensions(dim_nums_.batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. 
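+    // (Descriptor spatial dimension 0 is the innermost dimension, so XLA's
+    // spatial dimension `dim` maps to `effective_num_dimensions - dim - 1`.)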
+    input_descriptor.set_spatial_dim(
+        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
+        input_shape_.dimensions(dim_nums_.spatial_dimensions(dim)));
+  }

-  FilterDescriptor filter_descriptor;
+  FilterDescriptor filter_descriptor(effective_num_dimensions);
  filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
      .set_input_feature_map_count(
          filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension()))
-      .set_output_feature_map_count(
-          filter_shape_.dimensions(dim_nums_.kernel_output_feature_dimension()))
-      .set_input_filter_height(
-          filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(0)))
-      .set_input_filter_width(
-          filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(1)));
+      .set_output_feature_map_count(filter_shape_.dimensions(
+          dim_nums_.kernel_output_feature_dimension()));
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    filter_descriptor.set_spatial_dim(
+        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
+        filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim)));
+  }

-  ConvolutionDescriptor convolution_descriptor;
-  convolution_descriptor.set_zero_padding_width(width.padding_low())
-      .set_zero_padding_height(height.padding_low())
-      .set_horizontal_filter_stride(width.stride())
-      .set_vertical_filter_stride(height.stride());
+  ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    convolution_descriptor
+        .set_zero_padding(
+            static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
+            window_.dimensions(dim).padding_low())
+        .set_filter_stride(
+            static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
+            window_.dimensions(dim).stride());
+  }

-  BatchDescriptor output_descriptor;
+  BatchDescriptor output_descriptor(effective_num_dimensions);
  output_descriptor.set_layout(DataLayout::kBatchDepthYX)
-      .set_height(output_shape_.dimensions(dim_nums_.spatial_dimensions(0)))
-      .set_width(output_shape_.dimensions(dim_nums_.spatial_dimensions(1)))
      .set_feature_map_count(
          output_shape_.dimensions(dim_nums_.feature_dimension()))
      .set_count(output_shape_.dimensions(dim_nums_.batch_dimension()));
+  for (int dim = 0; dim < num_dimensions; ++dim) {
+    output_descriptor.set_spatial_dim(
+        static_cast<se::dnn::DimIndex>(effective_num_dimensions - dim - 1),
+        output_shape_.dimensions(dim_nums_.spatial_dimensions(dim)));
+  }
+
+  // Add a singleton dimension in the 1D convolution case.
+  if (num_dimensions == 1) {
+    input_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
+    output_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
+    filter_descriptor.set_spatial_dim(static_cast<se::dnn::DimIndex>(0), 1);
+    convolution_descriptor
+        .set_zero_padding(static_cast<se::dnn::DimIndex>(0), 0)
+        .set_filter_stride(static_cast<se::dnn::DimIndex>(0), 1);
+  }

  se::DeviceMemory<float> input_data(
      buffer_allocations.GetDeviceAddress(input_buffer_));
@@ -228,15 +258,21 @@ tensorflow::Status ConvolutionThunk::Convolve(
std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms(
    se::StreamExecutor* stream_exec) const {
  std::vector<se::dnn::AlgorithmType> algorithms;
+  // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
+  // by default. Should send in conv parameters and enable it when
+  // ShouldIncludeWinogradNonfusedAlgo() returns true.
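+  // Each case below asks the StreamExecutor for the candidate algorithm list
+  // of the matching convolution kind, with the winograd-nonfused algorithms
+  // excluded per the TODO above.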
switch (convolution_kind_) { case ConvolutionKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms)); + CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( + /*with_winograd_nonfused=*/false, &algorithms)); break; case ConvolutionKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms)); + CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( + /*with_winograd_nonfused=*/false, &algorithms)); break; case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(&algorithms)); + CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false, + &algorithms)); break; } return algorithms; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index cd9568f6a25..aaf72935e61 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -70,9 +70,9 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. // Semantics of null hlo_instruction argument are as in Thunk. ConvolutionThunk(ConvolutionKind convolution_kind, - BufferAllocation::Index input_buffer, - BufferAllocation::Index filter_buffer, - BufferAllocation::Index output_buffer, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& filter_buffer, + const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, @@ -125,19 +125,19 @@ class ConvolutionThunk : public Thunk { // the best algorithm from some heuristics based on its parameters. perftools::gputools::dnn::AlgorithmConfig best_algorithm_; - ConvolutionKind convolution_kind_; + const ConvolutionKind convolution_kind_; - BufferAllocation::Index input_buffer_; - BufferAllocation::Index filter_buffer_; - BufferAllocation::Index output_buffer_; + const BufferAllocation::Slice input_buffer_; + const BufferAllocation::Slice filter_buffer_; + const BufferAllocation::Slice output_buffer_; - Shape input_shape_; - Shape filter_shape_; - Shape output_shape_; + const Shape input_shape_; + const Shape filter_shape_; + const Shape output_shape_; - Window window_; + const Window window_; - ConvolutionDimensionNumbers dim_nums_; + const ConvolutionDimensionNumbers dim_nums_; }; string ConvolutionKindToString( diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 76fb079bd4d..87858e94090 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -21,7 +21,7 @@ namespace xla { namespace gpu { CopyThunk::CopyThunk(const void* source_address, - BufferAllocation::Index destination_buffer, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size, const HloInstruction* hlo_instruction) : Thunk(Kind::kCopy, hlo_instruction), source_address_(source_address), diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 803e699bfdd..6b8c432715f 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -34,7 +34,7 @@ class CopyThunk : public Thunk { // device buffer `destination_buffer`. `mem_size` is the size of the data in // bytes. 
  CopyThunk(const void* source_address,
-            BufferAllocation::Index destination_buffer, uint64 mem_size,
+            const BufferAllocation::Slice& destination_buffer, uint64 mem_size,
            const HloInstruction* hlo_instruction);

  CopyThunk(const CopyThunk&) = delete;
@@ -46,8 +46,8 @@
 private:
  const void* source_address_;
-  BufferAllocation::Index destination_buffer_;
-  uint64 mem_size_;
+  const BufferAllocation::Slice destination_buffer_;
+  const uint64 mem_size_;
};

}  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 67c80bf93b1..2987c8913d7 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -113,7 +113,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    tensorflow::gtl::ArraySlice<llvm::Value*> operands,
    tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
    PrimitiveType output_type) const {
-  // Binary math functions tranform are of type [T] -> T.
+  // Binary math functions are of type [T] -> T.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type ≠ output type: %s ≠ %s",
@@ -175,7 +175,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    return make_sqrt();
  }

-  if (!hlo_module_config_.fast_math_disabled() &&
+  if (hlo_module_config_.debug_options().xla_enable_fast_math() &&
      IsFPLiteralWithValue(rhs, -.5)) {
    VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
    // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
@@ -270,69 +270,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) const {
  switch (hlo->opcode()) {
-    case HloOpcode::kPad:
-      return [=, &operand_to_generator](
-                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
-        auto index = padded_index;
-        llvm::Value* in_bounds =
-            llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
-        for (int i = 0; i < index.size(); ++i) {
-          auto index_typed_const = [=](int64 n) {
-            return llvm::ConstantInt::get(index[i]->getType(), n);
-          };
-          const auto& pad_dim = hlo->padding_config().dimensions(i);
-          index[i] = ir_builder_->CreateSub(
-              index[i], index_typed_const(pad_dim.edge_padding_low()));
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
-              "in_bounds");
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpEQ(
-                  index_typed_const(0),
-                  ir_builder_->CreateURem(
-                      index[i],
-                      index_typed_const(pad_dim.interior_padding() + 1))),
-              "in_bounds");
-          index[i] = ir_builder_->CreateSDiv(
-              index[i], index_typed_const(pad_dim.interior_padding() + 1));
-          in_bounds = ir_builder_->CreateAnd(
-              in_bounds,
-              ir_builder_->CreateICmpSLT(
-                  index[i],
-                  index_typed_const(hlo->operand(0)->shape().dimensions(i))),
-              "in_bounds");
-        }
-
-        // if (in_bounds) {
-        //   ret_value = operand0[index];  // source
-        // } else {
-        //   ret_value = *operand1;  // padding
-        // }
-        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
-                                           ir_builder_),
-            "pad_result_addr", ir_builder_);
-        llvm_ir::LlvmIfData if_data =
-            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
-        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
-                            operand_to_generator.at(hlo->operand(0))(index));
-        ir_builder_->CreateStore(operand_value, ret_value_addr);
-
-        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
-        TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
-                            operand_to_generator.at(hlo->operand(1))({}));
-        ir_builder_->CreateStore(padding_value, ret_value_addr);
-
-        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
-        // Don't create phi(operand_value, padding_value) here, because invoking
-        // operand_to_generator may create new basic blocks, making the parent
-        // of operand_value or padding_value no longer a predecessor of
-        // if_data.after_block.
-        return ir_builder_->CreateLoad(ret_value_addr);
-      };
    case HloOpcode::kMap:
      return [=, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index fb053b62a75..afb78b8300b 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -16,9 +16,10 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"

#include <algorithm>
+#include <vector>

+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
-#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -98,8 +99,9 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) {
  double bytes = CalculateBytesReadByFusionInstruction(fusion);
  // Add bytes written to root instructions buffer.
  bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
-  // Calculate flops for all fused instructions.
-  HloCostAnalysis analysis;
+  // Calculate flops for all fused instructions. Use a null shape size function
+  // because we don't care about bytes accessed by the ops.
+  HloCostAnalysis analysis([](const Shape& shape) { return 0; });
  TF_CHECK_OK(fusion->fused_expression_root()->Accept(&analysis));
  // Return flops / bytes.
  return bytes > 0.0 ? analysis.flop_count() / bytes : analysis.flop_count();
@@ -219,7 +221,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
                  fusion->fused_instructions().end(),
                  [](const std::unique_ptr<HloInstruction>& instruction) {
                    if (instruction->opcode() != HloOpcode::kParameter &&
-                        IsExpensive(*instruction)) {
+                        GpuInstructionFusion::IsExpensive(*instruction)) {
                      return false;
                    }
                    return true;
@@ -248,7 +250,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
    return Status::OK();
  }
  // Merge fused instructions from 'fusion' into each user.
-  std::set<HloInstruction*> users = fusion->users();
+  std::vector<HloInstruction*> users = fusion->users();
  for (HloInstruction* user : users) {
    user->MergeFusionInstruction(fusion);
    changed_ = true;
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 9a989d26f93..bd720f8584f 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -25,7 +25,7 @@ namespace gpu {
// An HLO pass that attempts to merge fusion instructions to reduce kernel
// launch overhead and improve data locality.
//
-// Fusion instructions are merged into their users if two conditons are met:
+// Fusion instructions are merged into their users if two conditions are met:
//
// 1) The flops_to_bytes ratio of the fusion instruction is below the threshold
//    value of 1.0.
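As an aside (not part of the patch): the merge gate described in the comment above reduces to a small arithmetic-intensity check. A minimal sketch, assuming `flops` and `bytes` stand in for the values fusion_merger.cc derives from `HloCostAnalysis::flop_count()` and its byte-counting helpers:

```c++
// Illustrative only; mirrors the ratio logic in CalculateFlopsToBytesRatio.
constexpr double kFlopsToBytesThreshold = 1.0;

// Returns true when a fusion with the given flop count and bytes
// read/written stays below the merge threshold. A zero-byte fusion
// falls back to comparing the raw flop count, as the pass does.
bool BelowFlopsToBytesThreshold(double flops, double bytes) {
  const double ratio = bytes > 0.0 ? flops / bytes : flops;
  return ratio < kFlopsToBytesThreshold;
}
```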
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index a87e66ca869..8afc32dea97 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -25,7 +25,7 @@ namespace { class FusionMergerTest : public HloTestBase { protected: - FusionMergerTest() : module_(TestName()) {} + FusionMergerTest() : module_(CreateNewModule()) {} // Builds the following computation: // @@ -86,7 +86,7 @@ class FusionMergerTest : public HloTestBase { // Create output Tuple. builder.AddInstruction(HloInstruction::CreateTuple({out0, out1, out2})); - return module_.AddEntryComputation(builder.Build()); + return module_->AddEntryComputation(builder.Build()); } // Builds the following computation: @@ -154,7 +154,7 @@ class FusionMergerTest : public HloTestBase { // Create output Tuple. builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_.AddEntryComputation(builder.Build()); + return module_->AddEntryComputation(builder.Build()); } // Builds the following computation: @@ -225,7 +225,7 @@ class FusionMergerTest : public HloTestBase { // Create output Tuple. builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_.AddEntryComputation(builder.Build()); + return module_->AddEntryComputation(builder.Build()); } Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); @@ -235,7 +235,7 @@ class FusionMergerTest : public HloTestBase { Shape tuple_shape4_ = ShapeUtil::MakeTupleShape( {data_shape_, data_shape_, data_shape_, data_shape_}); - HloModule module_; + std::unique_ptr module_; }; // Tests that we can merge a fusion instruction that is below threshold. @@ -278,13 +278,15 @@ class FusionMergerTest : public HloTestBase { TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { auto computation = BuildComputation0(); // Run standard fusion passes. - EXPECT_TRUE( - GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie()); - EXPECT_FALSE( - GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .ValueOrDie()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .ValueOrDie()); // Run fusion merger pass, which should merge the shared fusion instruction // into its two users. - EXPECT_TRUE(FusionMerger().Run(&module_).ValueOrDie()); + EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); auto* root = computation->root_instruction(); EXPECT_EQ(HloOpcode::kTuple, root->opcode()); @@ -338,14 +340,16 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { BuildComputation1(); // Run standard fusion passes. - EXPECT_TRUE( - GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie()); - EXPECT_FALSE( - GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .ValueOrDie()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .ValueOrDie()); // Run fusion merger pass, which should detect that the flops/bytes of the // shared fusion instruction exceeds the threshold ratio, and therefore // cannot be merged with other fusion instructions. 
- EXPECT_FALSE(FusionMerger().Run(&module_).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is exceeded. @@ -388,13 +392,15 @@ TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { BuildComputation2(/*add_extra_input=*/true); // Run standard fusion passes. - EXPECT_TRUE( - GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie()); - EXPECT_FALSE( - GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .ValueOrDie()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .ValueOrDie()); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would increase. - EXPECT_FALSE(FusionMerger().Run(&module_).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is not exceeded. @@ -442,15 +448,21 @@ TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { BuildComputation2(/*add_extra_input=*/false); // Run standard fusion passes. - EXPECT_TRUE( - GpuInstructionFusion(/*may_duplicate=*/false).Run(&module_).ValueOrDie()); - EXPECT_FALSE( - GpuInstructionFusion(/*may_duplicate=*/true).Run(&module_).ValueOrDie()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .ValueOrDie()); + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .ValueOrDie()); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would not increase. - EXPECT_TRUE(FusionMerger().Run(&module_).ValueOrDie()); + EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); } } // namespace } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 98a8a4a2b1c..e784046450e 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -27,8 +27,6 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -using Index = BufferAllocation::Index; - namespace { // This struct contains the metadata of a matrix, e.g., its base address and @@ -47,63 +45,171 @@ struct MatrixDescriptor { int64 num_cols; }; -// Performs a gemm call on lhs_matrix and rhs_matrix and stores the result to -// output_matrix. +// Performs a gemm call without an explicit algorithm on lhs_matrix and +// rhs_matrix, and stores the result to output_matrix. template -tensorflow::Status DoGemm(MatrixDescriptor lhs_matrix, - MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { +bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { DCHECK(!output_matrix.transpose); se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); - bool launch_ok = - stream - ->ThenBlasGemm( - lhs_matrix.transpose ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose, - rhs_matrix.transpose ? 
se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose, - output_matrix.num_rows, output_matrix.num_cols, - lhs_matrix.transpose - ? lhs_matrix.num_rows - : lhs_matrix.num_cols, // Size of the reduce dimension. - /*alpha=*/1.0, - lhs_data, - lhs_matrix.num_rows, // The leading dimension of LHS. - rhs_data, - rhs_matrix.num_rows, // The leading dimension of RHS. - /*beta=*/0.0, &output_data, - output_matrix - .num_rows) // The leading dimension of the output matrix. - .ok(); - if (!launch_ok) { - return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); - } - return tensorflow::Status::OK(); + auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose : se::blas::Transpose::kNoTranspose; + auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose : se::blas::Transpose::kNoTranspose; + auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); } -// Return, if the given type is a valid Gemm elemental type, the executor for -// that type, else null. -// TODO(b/27202055): consider more element types. -std::function -FindGemmExecutor(PrimitiveType type) { +// Like DoGemm, but takes an explicit computation type and algorithm. +// computation_type specifies the type of intermediate values generated during +// the matmul (e.g. your input/output matrices could be f16s but you could do +// computations with f32s). algorithm is an opaque identifier which functions +// as a hint to cublas. +// +// Not all algorithms are valid for all matrix sizes, and not all CUDA versions +// and GPUs even support gemm-with-algorithm. So expect that this may fail +// unless you've already checked that it works for this particular GPU + input +// size. +// +// If you pass a non-null ProfileResult, this will always return true (assuming +// the Stream was valid to begin with); check the is_valid property of the +// ProfileResult to see whether the call actually succeeded. +template +bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, + MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, + se::blas::ComputationType computation_type, + se::blas::AlgorithmType algorithm, se::Stream* stream, + se::blas::ProfileResult* output_profile_result) { + DCHECK(!output_matrix.transpose); + + se::DeviceMemory lhs_data(lhs_matrix.data); + se::DeviceMemory rhs_data(rhs_matrix.data); + se::DeviceMemory output_data(output_matrix.data); + + auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose + : se::blas::Transpose::kNoTranspose; + auto k = lhs_matrix.transpose ?
lhs_matrix.num_rows : lhs_matrix.num_cols; + + return stream + ->ThenBlasGemmWithAlgorithm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows, + computation_type, algorithm, output_profile_result) + .ok(); +} + +// Experimentally tries to pick the best algorithm for the given gemm. +// +// This may fail under perfectly normal circumstances. In particular, it will +// fail if the program was built with < CUDA 8 or if we're using a gpu older +// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at +// all. +template +StatusOr DoGemmAutotune( + MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::blas::ComputationType computation_type, + se::Stream* stream) { + std::vector algorithms; + CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); + + se::blas::ProfileResult best_result; + for (auto algorithm : algorithms) { + se::blas::ProfileResult profile_result; + // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail + // for all algorithms if we're targeting < sm_50. But because we pass a + // non-null ProfileResult, DoGemmWithAlgorithm should always return true, + // and the actual success-ness is returned in ProfileResult::is_valid. + DCHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, + computation_type, algorithm, stream, + &profile_result)); + + if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } + + if (best_result.is_valid()) { + return best_result.algorithm(); + } + + return InternalError( + "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "ran successfully", + stream, algorithms.size()); +} + +// Helper functions to go from a PrimitiveType to a templated version of +// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. +auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { switch (type) { case F32: return &DoGemm; case F64: return &DoGemm; default: - return nullptr; + LOG(FATAL) << "Unsupported type."; + } +} +auto GetGemmWithAlgorithmFn(PrimitiveType type) + -> decltype(&DoGemmWithAlgorithm) { + switch (type) { + case F32: + return &DoGemmWithAlgorithm; + case F64: + return &DoGemmWithAlgorithm; + default: + LOG(FATAL) << "Unsupported type."; + } +} +auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { + switch (type) { + case F32: + return &DoGemmAutotune; + case F64: + return &DoGemmAutotune; + default: + LOG(FATAL) << "Unsupported type."; + } +} + +// Converts from an XLA PrimitiveType to a blas::ComputationType, which is used +// to specify the precision with which matmul computations should be performed, +// separately from the precision of the inputs and result. 
+se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { + switch (type) { + case F32: + return se::blas::ComputationType::kF32; + case F64: + return se::blas::ComputationType::kF64; + default: + LOG(FATAL) << "Unsupported type."; } } } // namespace -GemmThunk::GemmThunk(Index lhs_buffer, Index rhs_buffer, Index output_buffer, +GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, const HloInstruction* hlo_instruction) @@ -120,8 +226,6 @@ GemmThunk::GemmThunk(Index lhs_buffer, Index rhs_buffer, Index output_buffer, tensorflow::Status GemmThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; - auto executor = FindGemmExecutor(output_shape_.element_type()); - DCHECK(executor != nullptr); se::DeviceMemoryBase lhs_data = buffer_allocations.GetDeviceAddress(lhs_buffer_); @@ -141,7 +245,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( // Therefore, we need to convert dot between row-major matrices to that // between column-major matrices. The key insight for the conversion is that, // in linear storage, matrix M in column-major order is identical to the - // tranpose of M in row-major order. In other words, + // transpose of M in row-major order. In other words, // // column-major(M) = row-major(M^T). // @@ -172,17 +276,66 @@ tensorflow::Status GemmThunk::ExecuteOnStream( make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); const MatrixDescriptor rhs_descriptor = make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + + // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to + // autotune this gemm to figure out the best algorithm. + auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { + PrimitiveType element_type = output_shape_.element_type(); + se::blas::ComputationType computation_type = + GetBlasComputationType(element_type); + + const string& device_name = stream->parent()->GetDeviceDescription().name(); + auto autotune_it = autotune_results_.find(device_name); + if (autotune_it == autotune_results_.end()) { + StatusOr best_algorithm = + GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + computation_type, stream); + autotune_it = + autotune_results_.insert({device_name, best_algorithm}).first; + + if (autotune_it->second.ok()) { + VLOG(2) << "Autotune on GemmThunk " << this + << " successful; best algorithm is " + << best_algorithm.ValueOrDie(); + } else { + VLOG(2) << "Autotune on GemmThunk " << this + << " unsuccessful. 
Will use generic gemm."; + } + } + + const StatusOr& best_algorithm = + autotune_it->second; + if (best_algorithm.ok()) { + auto algorithm = best_algorithm.ValueOrDie(); + VLOG(2) << "Using algorithm " << algorithm + << " chosen by autotuning on GemmThunk " << this; + return GetGemmWithAlgorithmFn(element_type)( + lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm, + stream, + /*output_profile_result=*/nullptr); + } + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + stream); + }; + + bool launch_ok; if (output_shape_.layout().minor_to_major(0) == 0) { - return executor( + launch_ok = launch( lhs_descriptor, rhs_descriptor, MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), stream); } else { - return executor( + launch_ok = launch( rhs_descriptor, lhs_descriptor, MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), stream); } + + if (!launch_ok) { + return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); + } + return tensorflow::Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 7c8574d2752..983cb872924 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -37,11 +37,11 @@ class GemmThunk : public Thunk { // Constructs a thunk that computes "output = lhs rhs" using BLAS gemm. // transpose_lhs and transpose_rhs indicate whether gemm should transpose the // lhs and rhs operand. hlo_instruction is as in Thunk. - GemmThunk(BufferAllocation::Index lhs_buffer, - BufferAllocation::Index rhs_buffer, - BufferAllocation::Index output_buffer, const Shape& lhs_shape, - const Shape& rhs_shape, const Shape& output_shape, - bool transpose_lhs, bool transpose_rhs, + GemmThunk(const BufferAllocation::Slice& lhs_buffer, + const BufferAllocation::Slice& rhs_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; @@ -53,16 +53,24 @@ class GemmThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - BufferAllocation::Index lhs_buffer_; - BufferAllocation::Index rhs_buffer_; - BufferAllocation::Index output_buffer_; + const BufferAllocation::Slice lhs_buffer_; + const BufferAllocation::Slice rhs_buffer_; + const BufferAllocation::Slice output_buffer_; - Shape lhs_shape_; - Shape rhs_shape_; - Shape output_shape_; + const Shape lhs_shape_; + const Shape rhs_shape_; + const Shape output_shape_; - bool transpose_lhs_; - bool transpose_rhs_; + const bool transpose_lhs_; + const bool transpose_rhs_; + + // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune + // results. The map's value is the best algorithm we've found for this thunk + // on this device, or an error if none of the algorithms worked and we should + // use the regular gemm without an algorithm. + std::unordered_map> + autotune_results_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index a7e5c5226f9..86137a569f9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -24,10 +24,12 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/LLVMContext.h" #include "external/llvm/include/llvm/IR/Module.h" #include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" #include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -42,15 +44,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -79,6 +84,13 @@ const char* kTargetTriple = "nvptx64-nvidia-cuda"; // NVPTXTargetMachine.cpp. const char* kDataLayout = "e-i64:64-v16:16-v32:32-n16:32:64"; +// Any address of a variable residing in global memory or returned by one of the +// memory allocation routines from the driver or runtime API is always aligned +// to at least 256 bytes. +// +// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses +constexpr int64 kMemoryAlignment = 256; + // Returns the directory containing nvvm libdevice files. This function is // called in GpuCompiler's constructor, so can't return an error. But // GpuCompiler::Compile will return an error when the wanted libdevice file @@ -114,6 +126,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, const se::DeviceDescription& device_desc) { { HloPassPipeline pipeline("optimization", dump_hlo); + pipeline.AddInvariantChecker(); { auto& pass = pipeline.AddPass>( "simplification", dump_hlo); @@ -121,10 +134,16 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass(); - pipeline.AddPass(ImplementedAsGemm); - pipeline.AddPass(); + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return ImplementedAsGemm(dot) ? 
candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); @@ -141,17 +160,17 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. tensorflow::Status PrepareHloModuleForIrEmitting( - const Compiler::HloDumper& dump_hlo, HloModule* hlo_module, - HloModuleConfig* module_config) { + const Compiler::HloDumper& dump_hlo, HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo); + pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass( - module_config->mutable_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -161,16 +180,20 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). + // instruction which materializes a value). DCE must be run immediately before + // (and sometime after) copy insertion, to avoid dead code from interfering + // with the rewrites. + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } // Invokes the ptxas tool on the given PTX string, and dumps its output. void DumpPtxasInfo(const string& ptx) { - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - const string ptxas_path = flags->xla_ptxas_path; + const string ptxas_path = + tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); // Do not log PTX stats if ptxas is not found at the given path. 
if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) { LOG(WARNING) @@ -206,18 +229,18 @@ void DumpPtxasInfo(const string& ptx) { } // namespace -GpuCompiler::GpuCompiler() : libdevice_dir_(GetLibdeviceDir()) {} +GpuCompiler::GpuCompiler() + : libdevice_dir_(GetLibdeviceDir()), + pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} StatusOr> GpuCompiler::Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(hlo_module.get(), dump_hlo, + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), dump_hlo, stream_exec->GetDeviceDescription())); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, hlo_module.get(), - module_config.get())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -230,42 +253,45 @@ StatusOr> GpuCompiler::Compile( }; llvm_context.setDiagnosticHandler(DiagnosticHandler, &printer); - llvm::Module llvm_module(hlo_module->name().c_str(), llvm_context); + llvm::Module llvm_module(module->name().c_str(), llvm_context); // Set the target triple and the data layout. llvm_module.setTargetTriple(kTargetTriple); llvm_module.setDataLayout(kDataLayout); - const llvm::DataLayout& data_layout = llvm_module.getDataLayout(); - int64 pointer_size = data_layout.getPointerSize(); // Determine the HLO schedule, which is an ordering of HLO instructions. This // is used by buffer assignment to enable buffer reuse, and the same ordering // must also be used to determine the thunk launch schedule. - std::unique_ptr stream_assignment = - AssignStreams(*hlo_module); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_schedule, - HloSchedule::Build(*hlo_module, *stream_assignment)); + std::unique_ptr stream_assignment = AssignStreams(*module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + HloSchedule::Build(*module, *stream_assignment, pointer_size_)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(hlo_module.get(), hlo_schedule->ConsumeHloOrdering(), - pointer_size)); - auto temp_buffer_offsets = MakeUnique(*buffer_assignment); + BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), kMemoryAlignment)); - IrEmitterContext ir_emitter_context( - hlo_module.get(), buffer_assignment.get(), temp_buffer_offsets.get(), - &stream_exec->GetDeviceDescription(), &llvm_module); + legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); + if (!flags->xla_gpu_dump_debug_json_to.empty()) { + HloProto proto = MakeHloProto(*module, *buffer_assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( + proto, flags->xla_gpu_dump_debug_json_to, module->name())); + } - HloComputation* entry_computation = hlo_module->entry_computation(); - IrEmitterUnnested ir_emitter(*module_config, entry_computation, - module_config->has_hybrid_result(), + IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), + &stream_exec->GetDeviceDescription(), + &llvm_module); + + HloComputation* entry_computation = module->entry_computation(); + IrEmitterUnnested ir_emitter(module->config(), entry_computation, + module->config().has_hybrid_result(), &ir_emitter_context); TF_RETURN_IF_ERROR( entry_computation->root_instruction()->Accept(&ir_emitter)); string ir_module_string_before_opt; - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); VLOG(2) << "LLVM module before optimizations:"; @@ -279,8 +305,16 @@ StatusOr> GpuCompiler::Compile( generated_ptxes_.emplace_back(MakeUnique()); ptx = generated_ptxes_.back().get(); } - TF_ASSIGN_OR_RETURN( - *ptx, CompileToPtx(&llvm_module, *module_config, libdevice_dir_)); + int cc_major, cc_minor; + if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; + } + TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, + module->config(), libdevice_dir_)); VLOG(2) << "LLVM module after optimizations:"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); @@ -297,9 +331,8 @@ StatusOr> GpuCompiler::Compile( XLA_VLOG_LINES(2, thunk_schedule->ToString()); auto* gpu_executable = - new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(hlo_module), - std::move(module_config), std::move(buffer_assignment), - std::move(temp_buffer_offsets)); + new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), + std::move(buffer_assignment), ShapeSizeBytesFunction()); if (flags->xla_gpu_embed_ir) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -308,9 +341,8 @@ StatusOr> GpuCompiler::Compile( } StatusOr>> GpuCompiler::Compile( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector> modules, HloDumper dump_hlos, + std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } @@ -318,7 +350,6 @@ StatusOr>> GpuCompiler::Compile( StatusOr>> GpuCompiler::CompileAheadOfTime( std::vector> module, - std::vector> module_config, HloDumper dump_hlo, const AotCompilationOptions& options) { return Unimplemented("not yet 
implemented: GpuCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index a074607760f..da52f5ab1f8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -42,24 +41,28 @@ class GpuCompiler : public Compiler { ~GpuCompiler() override {} StatusOr> Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> hlo_module, - std::vector> module_config, - HloDumper dump_hlo, + std::vector> modules, HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime( std::vector> module, - std::vector> module_config, HloDumper dump_hlo, AotCompilationOptions const& options) override; perftools::gputools::Platform::Id PlatformId() const override; + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { + // Capture just the pointer size, not the entire GpuCompiler object. + int64 pointer_size = pointer_size_; + return [pointer_size](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + } + private: // The parent directory of libdevice IR libraries. const string libdevice_dir_; @@ -70,6 +73,9 @@ class GpuCompiler : public Compiler { tensorflow::mutex mutex_; std::vector> generated_ptxes_ GUARDED_BY(mutex_); + // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. + int64 pointer_size_; + TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index f654ffd22d5..7f9e60460c2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -47,8 +47,12 @@ class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. 
explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, - se::Stream* stream) - : do_profile_(do_profile), profile_(profile), stream_(stream) { + se::Stream* stream, + const HloComputation* computation) + : do_profile_(do_profile), + profile_(profile), + stream_(stream), + computation_(computation) { if (do_profile_) { clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); @@ -66,8 +70,8 @@ class HloExecutionProfiler { if (do_profile_) { stream_->ThenStopTimer(execution_timer_.get()); stream_->BlockHostUntilDone(); - profile_->set_total_cycles_executed(execution_timer_->Nanoseconds() * - clock_rate_ghz_); + profile_->set_total_cycles_executed( + *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); } } @@ -94,6 +98,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; + const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; }; @@ -105,30 +110,33 @@ class HloExecutionProfiler { GpuExecutable::GpuExecutable( tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr assignment, - std::unique_ptr temp_buffer_offsets) - : Executable(std::move(hlo_module), std::move(module_config)), + HloCostAnalysis::ShapeSizeFunction shape_size_function) + : Executable(std::move(hlo_module), std::move(shape_size_function)), ptx_(ptx), thunk_schedule_(std::move(thunk_schedule)), - assignment_(std::move(assignment)), - temp_buffer_offsets_(std::move(temp_buffer_offsets)) {} + assignment_(std::move(assignment)) {} Status GpuExecutable::ExecuteThunks( - se::Stream* main_stream, const BufferAllocations& buffer_allocations, + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, bool block_host_until_done, HloExecutionProfile* hlo_execution_profile) { + se::Stream* main_stream = run_options->stream(); + bool do_profile = hlo_execution_profile != nullptr; if (do_profile) { LOG(WARNING) << "PROFILING: profiling is enabled"; } - HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream); + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, + hlo_module_->entry_computation()); - std::vector> sub_streams; // Stream 0 indicates `main_stream` and substreams start from stream 1. - for (int32 i = 1; i < thunk_schedule_->StreamCount(); ++i) { - auto sub_stream = MakeUnique(main_stream->parent()); - sub_stream->Init(); - sub_streams.emplace_back(std::move(sub_stream)); + std::vector::SmartPtr> sub_streams; + while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { + sub_streams.emplace_back(); + TF_ASSIGN_OR_RETURN( + sub_streams.back(), + run_options->BorrowStream(main_stream->parent()->device_ordinal())); } std::map> thunk_to_finish_event; @@ -160,7 +168,7 @@ Status GpuExecutable::ExecuteThunks( // Make sure kernels are completed before deallocating temporary buffers. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. 
- if (!main_stream->BlockHostUntilDone()) { + if (block_host_until_done && !main_stream->BlockHostUntilDone()) { return InternalError("Failed to complete all kernels launched on stream %p", main_stream); } @@ -169,7 +177,7 @@ Status GpuExecutable::ExecuteThunks( } StatusOr GpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); @@ -188,19 +196,22 @@ StatusOr GpuExecutable::ExecuteOnStream( } } se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN(auto buffer_allocations, - buffer_allocations_builder.Build( - *assignment_, *temp_buffer_offsets_, - executor->device_ordinal(), memory_allocator)); + TF_ASSIGN_OR_RETURN( + auto buffer_allocations, + buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), + memory_allocator)); - TF_RETURN_IF_ERROR( - ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); + bool block_host_until_done = + !memory_allocator->AllowsAsynchronousDeallocation(); + TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, + block_host_until_done, + hlo_execution_profile)); HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation* output_allocation, - assignment_->GetUniqueTopLevelOutputAllocation()); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice output_slice, + assignment_->GetUniqueTopLevelOutputSlice()); se::DeviceMemoryBase output_buffer_address = - buffer_allocations->GetDeviceAddress(output_allocation->index()); + buffer_allocations->GetDeviceAddress(output_slice.index()); if (ShapeUtil::IsTuple(root->shape())) { std::set referred_by_output; @@ -217,21 +228,21 @@ StatusOr GpuExecutable::ExecuteOnStream( // The points-to set of the root is unambiguous so it's known statically // which buffers are in the result. Gather these buffers using the root's // points-to set. - TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElement( + TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus( [&referred_by_output, &buffer_allocations, this]( - const ShapeIndex& /*index*/, bool /*is_leaf*/, + const ShapeIndex& /*index*/, const std::vector& buffers) { // The points to set is unambiguous so the set should be a // singleton. That is, we know exactly which instruction produced // the array at this element. 
CHECK_EQ(1, buffers.size()); HloInstruction* hlo = buffers[0]->instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - this->assignment_->GetUniqueAllocation( - hlo, buffers[0]->index())); - CHECK(!allocation->is_entry_computation_parameter()); + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(hlo, buffers[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); referred_by_output.insert( - buffer_allocations->GetDeviceAddress(allocation->index())); + buffer_allocations->GetDeviceAddress(slice.index())); return Status::OK(); })); } @@ -247,10 +258,9 @@ StatusOr GpuExecutable::ExecuteOnStream( } StatusOr> GpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); // This ExecuteOnStream overload should only be called by the LocalService // which sets has_hybrid_result to true. @@ -273,14 +283,17 @@ StatusOr> GpuExecutable::ExecuteOnStream( i, arguments[param_no]->buffer(/*index=*/{})); } } - se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN(auto buffer_allocations, - buffer_allocations_builder.Build( - *assignment_, *temp_buffer_offsets_, - executor->device_ordinal(), memory_allocator)); + se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( + auto buffer_allocations, + buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), + memory_allocator)); - TF_RETURN_IF_ERROR( - ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); + bool block_host_until_done = + !memory_allocator->AllowsAsynchronousDeallocation(); + TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, + block_host_until_done, + hlo_execution_profile)); HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); @@ -293,10 +306,10 @@ StatusOr> GpuExecutable::ExecuteOnStream( std::set buffers_in_result; TF_RETURN_IF_ERROR( shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( + ->ForEachMutableElementWithStatus( [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( - const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { - if (is_leaf) { + const ShapeIndex& index, size_t* buffer_entry) { + if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { const std::vector& sources = this->GetRootPointsToSet().element(index); // The points to set is unambiguous so the set should be a @@ -309,13 +322,13 @@ StatusOr> GpuExecutable::ExecuteOnStream( // The source instruction should have a non-parameter buffer // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - this->assignment_->GetUniqueAllocation( + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( src_hlo, sources[0]->index())); - CHECK(!allocation->is_entry_computation_parameter()); + CHECK(!slice.allocation()->is_entry_computation_parameter()); perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(allocation->index()); + buffer_allocations->GetDeviceAddress(slice.index()); CHECK(!src_base.is_null() || src_base.size() == 0); shaped_buffer->mutable_buffers()->push_back(src_base); *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; @@ -330,115 +343,8 @@ StatusOr> GpuExecutable::ExecuteOnStream( return std::move(shaped_buffer); } -Status GpuExecutable::ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // This ExecuteOnStream overload should only be called by the LocalService - // which sets has_hybrid_result to true. - TF_RET_CHECK(module_config().has_hybrid_result()); - - // Every array element in the result of the computation must be unambiguously - // produced by a single instruction. - // This ensures that the buffers inside result_buffer can be assigned without - // conflict to the respective instructions because there is a one-to-one - // correspondence between hlo instructions and array buffers in the result. - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented( - "Points-to set of root instruction is ambiguous or not distinct"); - } - - DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape())); - - BufferAllocations::Builder buffer_allocations_builder; - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_entry_computation_parameter()) { - auto param_no = allocation.parameter_number(); - if (ShapeUtil::IsTuple(arguments[param_no]->shape())) { - return Unimplemented("Tuple ShapedBuffer arguments not supported"); - } - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->buffer(/*index=*/{})); - } - } - - // If two tuple elements point to the same buffer, one of the results in the - // result buffer is considered the canonical location while the other result - // points to it (instead of, say, making a copy of the result). - // buffer_index_to_shape_index maps a buffer index to its canonical location - // in the result buffer. - std::unordered_map - buffer_index_to_shape_index; - - // Register DeviceMemoryBase values in result_buffer to their corresponding - // buffer indices. These buffers will not be allocated in the call to - // BufferAllocationsBuilder::Build. - std::set buffers_in_result; - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( - [&buffer_allocations_builder, &buffers_in_result, - &buffer_index_to_shape_index, result_buffer, this]( - const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) { - if (is_leaf) { - const std::vector& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. 
- CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation, - this->assignment_->GetUniqueAllocation( - src_hlo, sources[0]->index())); - CHECK(!allocation->is_entry_computation_parameter()); - - auto insert_result = buffer_index_to_shape_index.emplace( - allocation->index(), *buffer_entry); - if (insert_result.second) { - // The points-to set is distinct so this buffer should not - // have been assigned in a previous invocation of this - // lambda. - perftools::gputools::DeviceMemoryBase memory_base = - result_buffer->buffer(index); - CHECK(!memory_base.is_null()); - buffer_allocations_builder.RegisterBuffer( - allocation->index(), memory_base); - buffers_in_result.insert(memory_base); - } else { - // Record the fact that this tuple element is identical to - // some - // prior result. - *buffer_entry = insert_result.first->second; - } - } - return Status::OK(); - })); - - se::StreamExecutor* executor = stream->parent(); - auto device_ordinal = executor->device_ordinal(); - TF_ASSIGN_OR_RETURN( - auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, *temp_buffer_offsets_, - device_ordinal, memory_allocator)); - - TF_RETURN_IF_ERROR( - ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile)); - - return buffer_allocations->TearDown(buffers_in_result, *assignment_); -} - StatusOr GpuExecutable::ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 2343d264dee..e1a55118fc7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -24,12 +24,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -52,9 +50,8 @@ class GpuExecutable : public Executable { GpuExecutable(tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr assignment, - std::unique_ptr temp_buffer_offsets); + HloCostAnalysis::ShapeSizeFunction shape_size_function); // This should be called after set_ir_module_string. 
const string& ir_module_string() const { return ir_module_string_; } @@ -68,30 +65,30 @@ class GpuExecutable : public Executable { tensorflow::StringPiece ptx() const { return ptx_; } StatusOr ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr> ExecuteOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - Status ExecuteOnStream( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - ShapedBuffer* result_buffer, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; private: - Status ExecuteThunks(perftools::gputools::Stream* stream, + // If `block_host_until_done` is false, execution will not block the host + // until the kernels have completed. This is used as an optimization for + // clients, such as Tensorflow, that use a single stream of execution for + // computations, and allow host-side deallocation from the allocator before + // GPU execution completes. + Status ExecuteThunks(const ServiceExecutableRunOptions* run_options, const BufferAllocations& buffer_allocations, + bool block_host_until_done, HloExecutionProfile* hlo_execution_profile); // Returns the points-to set of the root instruction of the entry @@ -117,10 +114,6 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; - // Owns the mapping from temporary buffers to their offsets in the temp-buffer - // memory block. - const std::unique_ptr temp_buffer_offsets_; - TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 404a53e13b7..d16a1d4ee5b 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -36,26 +36,42 @@ class GpuHloOrdering : public PredecessorHloOrdering { const std::vector& thunk_launch_order); ~GpuHloOrdering() override = default; + // Only the entry computation can possibly be sequentially ordered, and only + // if we've assigned all instructions to a single stream. + const std::vector* SequentialOrder( + const HloComputation& computation) const override { + return &computation == module_->entry_computation() ? entry_sequence_.get() + : nullptr; + } + string ToString() const override { return ToStringHelper("GpuHloOrdering"); } + + private: + std::unique_ptr> entry_sequence_; }; GpuHloOrdering::GpuHloOrdering( const HloModule* module, const StreamAssignment& stream_assignment, const std::vector& thunk_launch_order) : PredecessorHloOrdering(module) { + // The entry computation has a total order when there's only one stream. + if (stream_assignment.StreamCount() == 1) { + entry_sequence_ = + MakeUnique>(thunk_launch_order); + } + // The ordering of instructions for the entry computation is determined by the // total order of thunk launches, and stream assignment. Instructions are // sequential within a stream and concurrent across streams. 
In addition, the // GpuExecutable adds cross-stream dependency edges to ensure each instruction // waits for its operands before executing. // - // The predecessor map is built incrementally, in thunk launch - // order. We record the instructions already visited per stream in - // 'instructions_per_stream'. This lets us quickly determine the - // same-stream predecessors of each instruction. To capture - // cross-stream dependency edges, we use the predecessor map to - // insert each operand as well as its transitive closure of - // dependencies. + // The predecessor map is built incrementally, in thunk launch order. We + // record the instructions already visited per stream in + // 'instructions_per_stream'. This lets us quickly determine the same-stream + // predecessors of each instruction. To capture cross-stream dependency edges, + // we use the predecessor map to insert each operand as well as its transitive + // closure of dependencies. // Compute the set of all instructions we will want to set reachability on auto predecessor_map = MakeUnique( @@ -98,12 +114,9 @@ GpuHloOrdering::GpuHloOrdering( // dependencies. I.e. the strict predecessors of each subcomputation // instruction is its transitive operands. // - // TODO(toddw): Each subcomputation is actually emitted as a function in - // DFS - // postorder, so we can do better and establish the total order here. We - // don't - // do that yet since it's hard to ensure that the order here is the order - // used + // TODO(toddw): Each subcomputation is actually emitted as a function in DFS + // postorder, so we can do better and establish the total order here. We don't + // do that yet since it's hard to ensure that the order here is the order used // by IrEmitterNested. And mismatched ordering bugs would be hard to find. for (auto& computation : module->computations()) { if (computation.get() != module->entry_computation()) { @@ -113,20 +126,6 @@ GpuHloOrdering::GpuHloOrdering( } } -// Computes a topological launch_order based on depth-first order, visiting -// operands in essentially an arbitrary order. -// -// TODO(b/32006145): Use an ordering that minimizes memory pressure. -tensorflow::Status DFSLaunchOrder( - const HloComputation* computation, - std::vector* launch_order) { - return computation->root_instruction()->Accept( - [launch_order](HloInstruction* hlo) { - launch_order->push_back(hlo); - return tensorflow::Status::OK(); - }); -} - // Computes a topological launch_order that is close to a breadth-first // order. This heuristic works well for graphs where concurrent kernels are // located at the same layer. It can often reduce dependency between concurrent @@ -187,19 +186,24 @@ HloSchedule::HloSchedule() {} /* static */ StatusOr> HloSchedule::Build( - const HloModule& module, const StreamAssignment& stream_assignment) { + const HloModule& module, const StreamAssignment& stream_assignment, + int64 pointer_size) { std::unique_ptr schedule(new HloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. - const HloComputation* computation = module.entry_computation(); + const HloComputation* entry_computation = module.entry_computation(); if (stream_assignment.StreamCount() == 1) { - // DFS tends to increase buffer reuse, reducing memory usage. All kernels - // are launched on a single stream, so there's no loss of concurrency. 
- TF_RETURN_IF_ERROR( - DFSLaunchOrder(computation, &schedule->thunk_launch_order_)); + // All kernels are launched on a single stream, so there's no loss of + // concurrency by optimizing for minimal memory usage. + TF_ASSIGN_OR_RETURN( + schedule->thunk_launch_order_, + CreateMemoryMinimizingSequence( + *entry_computation, [pointer_size](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); + })); } else { // BFS tends to increase concurrency, but also increases memory usage. - BFSLaunchOrder(computation, &schedule->thunk_launch_order_); + BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } schedule->hlo_ordering_ = MakeUnique( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h index 42d9051aede..773973010a4 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h @@ -39,7 +39,8 @@ class HloSchedule { // Constructs an HloSchedule for the given module, based on the given stream // assignment. static StatusOr> Build( - const HloModule& module, const StreamAssignment& stream_assignment); + const HloModule& module, const StreamAssignment& stream_assignment, + int64 pointer_size); // Returns the total order of thunk launches, represented in terms of HLO // instructions. diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index 70628f11917..118ef18c44b 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include +#include + #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,10 +31,27 @@ namespace gpu { class HloScheduleTest : public HloTestBase { protected: - typedef std::vector hlovec; + using HloVec = std::vector; // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); + + static std::unique_ptr BuildHloSchedule( + const HloModule& module, const StreamAssignment& streams) { + return HloSchedule::Build(module, streams, /*pointer_size=*/8) + .ConsumeValueOrDie(); + } + + HloVec RemoveHlo(const HloVec& input, + const std::unordered_set& remove) { + HloVec result(input); + result.erase(std::remove_if(result.begin(), result.end(), + [&remove](const HloInstruction* x) { + return remove.count(x) > 0; + }), + result.end()); + return result; + } }; // Test of a single stream, where data dependencies fully determine the @@ -49,15 +69,17 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(dot2)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(dot2)); - std::unique_ptr streams = AssignStreams(module); + std::unique_ptr streams = AssignStreams(*module); EXPECT_EQ(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); - EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, dot1, z, dot2})); + auto schedule = BuildHloSchedule(*module, *streams); + // Remove parameters, which are unordered. 
+ EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), + HloVec({dot1, dot2})); // Parameters x,y,z are mutually unordered, while dot1 and dot2 are // transitively ordered by operands. @@ -107,17 +129,19 @@ TEST_F(HloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(add3)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(add3)); - std::unique_ptr streams = AssignStreams(module); + std::unique_ptr streams = AssignStreams(*module); EXPECT_EQ(streams->StreamNumberForHlo(*add1), streams->StreamNumberForHlo(*add2)); EXPECT_EQ(streams->StreamNumberForHlo(*add1), streams->StreamNumberForHlo(*add3)); - auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); - EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, add1, z, add2, add3})); + auto schedule = BuildHloSchedule(*module, *streams); + // Remove parameters, which are unordered. + EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), + HloVec({add1, add2, add3})); // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are // transitively ordered by operands. @@ -175,16 +199,18 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(add)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(add)); - std::unique_ptr streams = AssignStreams(module); + std::unique_ptr streams = AssignStreams(*module); EXPECT_NE(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); - EXPECT_TRUE(schedule->ThunkLaunchOrder() == hlovec({x, y, dot1, dot2, add}) || - schedule->ThunkLaunchOrder() == hlovec({x, y, dot2, dot1, add})); + auto schedule = BuildHloSchedule(*module, *streams); + // Remove parameters, which are unordered. + HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y}); + EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) || + thunk_launch_order == HloVec({dot2, dot1, add})); // Parameters x,y are mutually unordered, while dot1, dot2 and add are // transitively ordered by operands. @@ -228,6 +254,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); @@ -251,10 +278,10 @@ TEST_F(HloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(d40)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(d40)); - std::unique_ptr streams = AssignStreams(module); + std::unique_ptr streams = AssignStreams(*module); // The two dots on layer 1 are concurrent. EXPECT_NE(streams->StreamNumberForHlo(*d10), streams->StreamNumberForHlo(*d11)); @@ -271,12 +298,12 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // We don't check the thunk launch order, since there are many valid total // orders, and it's annoying to express. 
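Where two launch orders are equally valid, as with the two concurrent dots above, the tests enumerate the acceptable sequences. An alternative for cases with many valid total orders, shown here purely as an illustrative sketch with toy node ids, is to assert the dependency partial order directly:

```cpp
// Sketch: verify that a schedule respects operand-before-user edges instead
// of comparing it against one exact launch order.
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>

bool RespectsDependencies(
    const std::vector<int>& order,
    const std::vector<std::pair<int, int>>& operand_user_edges) {
  std::unordered_map<int, int> position;
  for (int i = 0; i < static_cast<int>(order.size()); ++i) {
    position[order[i]] = i;
  }
  for (const auto& edge : operand_user_edges) {
    if (position[edge.first] >= position[edge.second]) return false;
  }
  return true;
}

int main() {
  // Two dots on different streams may appear in either order, as long as
  // both precede the add that consumes them.
  std::vector<std::pair<int, int>> edges = {{0, 2}, {1, 2}};
  assert(RespectsDependencies({0, 1, 2}, edges));
  assert(RespectsDependencies({1, 0, 2}, edges));
  assert(!RespectsDependencies({2, 0, 1}, edges));
  return 0;
}
```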
- auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie(); + auto schedule = BuildHloSchedule(*module, *streams); auto order = schedule->ConsumeHloOrdering(); - const hlovec all_params( + const HloVec all_params( {params[0], params[1], params[2], params[3], params[4], params[5]}); - const hlovec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40}); + const HloVec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40}); // Parameters are mutually unordered, and never execute before ops. for (const HloInstruction* param : all_params) { @@ -366,3 +393,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index accc406c76f..1a61eec3537 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -41,7 +41,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // operand HLOs are already bound to avoid rebinding the same HLO. std::set already_bound_for_this_function; auto arg_iter = function->arg_begin(); - for (const auto* io_hlo : io_hlos) { + for (const HloInstruction* io_hlo : io_hlos) { if (!already_bound_for_this_function.count(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); @@ -56,7 +56,7 @@ void HloToIrBindings::EmitBasePointersForHlos( temp_buffer_base_ = &*arg_iter; temp_buffer_base_->setName("temp_buffer"); - for (auto* non_io_hlo : non_io_hlos) { + for (const HloInstruction* non_io_hlo : non_io_hlos) { if (already_bound_for_this_function.count(non_io_hlo)) { continue; } @@ -65,13 +65,13 @@ void HloToIrBindings::EmitBasePointersForHlos( if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) { if (!is_nested_) { // Lookup allocation GetTupleElement operand. - const BufferAllocation* allocation = + const BufferAllocation::Slice slice = buffer_assignment_ - ->GetUniqueTopLevelAllocation(LatestNonGteAncestor(non_io_hlo)) + ->GetUniqueTopLevelSlice(LatestNonGteAncestor(non_io_hlo)) .ConsumeValueOrDie(); // We are not in a nested context, so check non-thread-local allocation. - CHECK(!allocation->is_thread_local()); - int64 offset = temp_buffer_offsets_->GetOffset(allocation->index()); + CHECK(!slice.allocation()->is_thread_local()); + const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); // Emit IR for GetTupleElement instruction and bind to emitted value. llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP( @@ -89,15 +89,15 @@ void HloToIrBindings::EmitBasePointersForHlos( // A non-IO HLO with a buffer is bound to // (1) an alloca if it is thread-local, or // (2) an internal pointer in temp_buffer_base according to its offset. 
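These hlo_to_ir_bindings hunks replace a BufferAllocation plus a separate TempBufferOffsets table with BufferAllocation::Slice, which carries its own offset, so the bound address is simply base pointer plus slice offset. A toy sketch of the idea; the struct below is a hypothetical stand-in, not the real BufferAssignment API:

```cpp
// Hedged sketch: a slice names an allocation plus an (offset, size) range
// inside it, so callers no longer consult a separate offset table.
#include <cassert>
#include <cstdint>
#include <vector>

struct Slice {
  int allocation_index;
  int64_t offset;
  int64_t size;
};

int main() {
  std::vector<char> temp_buffer(1024);  // stands in for temp_buffer_base_
  Slice slice{/*allocation_index=*/0, /*offset=*/128, /*size=*/64};

  // The CreateInBoundsGEP in the real code computes base + slice.offset().
  char* bound_address = temp_buffer.data() + slice.offset;
  assert(bound_address - temp_buffer.data() == 128);
  return 0;
}
```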
- const BufferAllocation* allocation = - buffer_assignment_->GetUniqueTopLevelAllocation(non_io_hlo) + const BufferAllocation::Slice slice = + buffer_assignment_->GetUniqueTopLevelSlice(non_io_hlo) .ConsumeValueOrDie(); - if (allocation->is_thread_local()) { + if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type)); } else { - int64 offset = temp_buffer_offsets_->GetOffset(allocation->index()); + const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); BindHloToIrValue(*non_io_hlo, ir_builder_->CreateInBoundsGEP( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 1e3b2684239..5be2150801f 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -22,7 +22,6 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -37,10 +36,8 @@ class HloToIrBindings { public: HloToIrBindings(const HloModule& module, const BufferAssignment* buffer_assignment, - const TempBufferOffsets* temp_buffer_offsets, llvm::IRBuilder<>* ir_builder, bool is_nested) : buffer_assignment_(buffer_assignment), - temp_buffer_offsets_(temp_buffer_offsets), is_nested_(is_nested), ir_builder_(ir_builder), alias_analysis_(module, *buffer_assignment_, @@ -88,8 +85,6 @@ class HloToIrBindings { const BufferAssignment* buffer_assignment_; - const TempBufferOffsets* temp_buffer_offsets_; - const bool is_nested_; llvm::IRBuilder<>* ir_builder_; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc new file mode 100644 index 00000000000..120a3f7fba2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/core/platform/logging.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +InfeedManager::InfeedManager() + : current_buffer_(nullptr), + host_to_device_executor_(nullptr) {} + +void InfeedManager::Reset() { + tensorflow::mutex_lock l(mu_); + CHECK(!current_buffer_); + for (auto buffer : enqueued_buffer_) { + buffer->Done(); + } + enqueued_buffer_.clear(); +} + +void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { + tensorflow::mutex_lock l(mu_); + bool was_empty = enqueued_buffer_.empty(); + enqueued_buffer_.push_back(buffer); + if (was_empty) { + // This has the potential to suffer from the notified thread + // immediately trying and failing to acquire mu_, but seems + // preferable to the alternative of notifying outside the lock + // on every enqueue. + cv_.notify_one(); + } +} + +InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + CHECK(!current_buffer_); + current_buffer_ = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + return current_buffer_; +} + +void InfeedManager::ReleaseCurrentBuffer(se::DeviceMemoryBase* device_memory) { + tensorflow::mutex_lock l(mu_); + CHECK(current_buffer_); + CHECK(device_memory->IsSameAs(*current_buffer_->device_memory())); + current_buffer_->Done(); + current_buffer_ = nullptr; +} + +se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { + if (host_to_device_executor_ == nullptr) { + host_to_device_executor_ = executor; + host_to_device_stream_ = MakeUnique(executor); + host_to_device_stream_->Init(); + } + + if (executor != host_to_device_executor_) { + // The requested executor must be the same as the one for which + // the stream is cached. + return nullptr; + } + + return host_to_device_stream_.get(); +} + +InfeedManager* GetOrCreateInfeedManager() { + static InfeedManager* manager = new InfeedManager; + return manager; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h new file mode 100644 index 00000000000..50d0ce340f3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header declares classes for the infeed manager and the infeed +// buffer that are used by the GPU runtime to transfer buffers into an +// executing GPU computation, e.g., to feed data into a while loop. 
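The InfeedManager implementation above is a classic mutex/condition-variable producer-consumer queue that notifies only on the empty-to-non-empty transition. A minimal standalone sketch of that discipline, with std:: primitives standing in for tensorflow::mutex and tensorflow::condition_variable:

```cpp
// Minimal sketch: enqueue under a mutex, notify only when the queue goes
// from empty to non-empty, and block in dequeue until an element arrives.
#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

template <typename T>
class BlockingQueue {
 public:
  void Enqueue(T value) {
    std::unique_lock<std::mutex> lock(mu_);
    const bool was_empty = queue_.empty();
    queue_.push_back(std::move(value));
    if (was_empty) {
      // Waiters only sleep while the queue is empty, so notifying on every
      // enqueue would be wasted work.
      cv_.notify_one();
    }
  }

  T BlockingDequeue() {
    std::unique_lock<std::mutex> lock(mu_);
    while (queue_.empty()) {
      cv_.wait(lock);
    }
    T value = std::move(queue_.front());
    queue_.pop_front();
    return value;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};

int main() {
  BlockingQueue<int> queue;
  queue.Enqueue(42);
  return queue.BlockingDequeue() == 42 ? 0 : 1;
}
```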
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// TODO(b/30467474) Once GPU infeed implementation settles, consider +// folding back the cpu and gpu infeed implementations into a generic +// one if possible. +// +// Current limitations: +// * Does not handle multiple devices/replicas. +// +// * Buffer space on GPU is allocated on every infeed enqueue request, +// and it does not handle the case when it runs out of +// memory. Potential solution is to pre-allocate a fixed amount of +// memory and block when that memory is full. + +// Defines an infeed buffer that is passed to the runtime by +// the client. The client manages the memory of the buffer. +class InfeedBuffer { + public: + InfeedBuffer(perftools::gputools::StreamExecutor* executor, int64 length) + : executor_(executor), length_(length) { + device_memory_ = executor_->AllocateArray(length); + CHECK(!device_memory_.is_null()); + } + + ~InfeedBuffer() { executor_->Deallocate(&device_memory_); } + + int64 length() const { return length_; } + + // Callback to signal that this buffer is consumed. This helps the + // client to manage memory for the infeed buffers. + void Done() { delete this; } + + perftools::gputools::DeviceMemoryBase* device_memory() { + return &device_memory_; + } + + private: + perftools::gputools::StreamExecutor* executor_; // Not owned. + const int64 length_; + perftools::gputools::DeviceMemoryBase device_memory_; +}; + +// Client-side class used to enqueue infeed buffers. +class InfeedManager { + public: + InfeedManager(); + + // Calls the completion callback for any enqueued buffers that have + // not been dequeued by the runtime, and empties the infeed + // queue. Reset may not be called while a runtime computation is + // processing a dequeued buffer. The only safe way to ensure this + // condition is to call Reset when no computation is taking place. + void Reset(); + + // Adds buffer to the infeed queue. buffer->Done will be called when + // the buffer will no longer be accessed by the InfeedManager, + // either as a result of a call to Reset or because the runtime has + // dequeued and used the buffer. + void EnqueueBuffer(InfeedBuffer* buffer); + + // Blocks until the infeed queue is non-empty, then returns the + // buffer at the head of the queue. Sets the current buffer to be + // the returned buffer. It is an error to call BlockingDequeueBuffer + // if there is an unreleased current buffer, i.e., + // ReleaseCurrentBuffer must be called between calls to + // BlockingDequeueBuffer. + InfeedBuffer* BlockingDequeueBuffer(); + + // Releases the current buffer, which is the last buffer returned by + // BlockingDequeueBuffer and not yet released. device_memory must + // match that of the current buffer. + void ReleaseCurrentBuffer( + perftools::gputools::DeviceMemoryBase* device_memory); + + // Returns a cached stream associated with an executor. Allocates a + // new stream on the first invocation. On subsequent invocations, if + // the cached executor is not the same as the requested executor, + // returns null. 
+ perftools::gputools::Stream* GetStream( + perftools::gputools::StreamExecutor* executor); + + private: + tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is + // enqueued to an empty queue. + tensorflow::condition_variable cv_; + // InfeedBuffer* queue contents are not owned, but buffer->Done must + // be called when the buffer is no longer needed by the runtime. + std::deque enqueued_buffer_; + // If non-NULL, the buffer that is currently being processed by the + // runtime. Not owned. + InfeedBuffer* current_buffer_; + // Cached host to device stream for queuing infeed data. + std::unique_ptr host_to_device_stream_; + // Executor that the host_to_device_stream belongs to. Not owned. + perftools::gputools::StreamExecutor* host_to_device_executor_; +}; + +// Singleton creator-or-accessor: Returns the GPU infeed manager. +InfeedManager* GetOrCreateInfeedManager(); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc new file mode 100644 index 00000000000..6f144c7273e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +InfeedThunk::InfeedThunk(const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, const HloInstruction* hlo_instruction) + : Thunk(Kind::kInfeed, hlo_instruction), + destination_buffer_(destination_buffer), + mem_size_(mem_size) {} + +tensorflow::Status InfeedThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + VLOG(2) << "Infeeding to GPU "; + perftools::gputools::DeviceMemoryBase destination_data = + buffer_allocations.GetDeviceAddress(destination_buffer_); + + InfeedManager* infeed_manager = GetOrCreateInfeedManager(); + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + CHECK_EQ(buffer->length(), mem_size_); + stream->ThenMemcpy(&destination_data, *(buffer->device_memory()), + buffer->length()); + if (!stream->BlockHostUntilDone()) { + return InternalError("Failed to complete data transfer on stream %p", + stream); + } + // Since Infeeds are totally ordered, no other infeed should sneak + // in and we should be able to release the same buffer we dequeued. 
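GetOrCreateInfeedManager, declared above as a "singleton creator-or-accessor", is a create-on-first-use singleton that is intentionally never destroyed. A compact sketch of the pattern with a toy Manager class; since C++11 the local-static initialization is thread-safe, and leaking the instance sidesteps destruction-order problems at process shutdown:

```cpp
// Sketch of the "create on first use, never destroy" accessor.
#include <cstdio>

class Manager {
 public:
  int Sequence() { return ++count_; }

 private:
  int count_ = 0;
};

Manager* GetOrCreateManager() {
  static Manager* manager = new Manager;  // initialized once, never deleted
  return manager;
}

int main() {
  std::printf("%d\n", GetOrCreateManager()->Sequence());  // 1
  std::printf("%d\n", GetOrCreateManager()->Sequence());  // 2: same instance
  return 0;
}
```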
+ infeed_manager->ReleaseCurrentBuffer(buffer->device_memory()); + return tensorflow::Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h new file mode 100644 index 00000000000..0a808186c21 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +// A thunk that infeeds data. Data must be already resident on the +// device. This thunk performs an intra-device copy from that location +// to the buffer allocated for the infeed op. +class InfeedThunk : public Thunk { + public: + // Constructs a InfeedThunk that copies data from the on-device + // infeed queue to the device buffer + // `destination_buffer`. `mem_size` is the size of the data in + // bytes. + InfeedThunk(const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, const HloInstruction* hlo_instruction); + + InfeedThunk(const InfeedThunk&) = delete; + InfeedThunk& operator=(const InfeedThunk&) = delete; + + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice destination_buffer_; + const uint64 mem_size_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 91fd7ae77a9..a36dcbbd2fa 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -46,6 +46,16 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on GPUs. + if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + + // RNG operations are not currently parallel-friendly on GPU. + if (producer->opcode() == HloOpcode::kRng) { + return false; + } + // Do not fuse to-vector reduction into other consumers. They should be // unfused or the root of a kInput fusion. 
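The new ShouldFuse guards above reject kFusion producers (output fusion is not supported yet) and kRng producers (RNG parallelizes poorly on GPU). The early-exit shape of such a predicate, sketched with a stand-in opcode enum rather than the real HloOpcode:

```cpp
// Illustrative early-exit structure of a ShouldFuse-style predicate: each
// guard names one reason fusion is rejected.
#include <cassert>

enum class Opcode { kAdd, kFusion, kRng };

bool ShouldFuseProducer(Opcode producer) {
  // Fusing a fusion node into a consumer would be output fusion,
  // which this backend does not support yet.
  if (producer == Opcode::kFusion) return false;
  // RNG ops are rejected because they are not parallel-friendly on GPU.
  if (producer == Opcode::kRng) return false;
  return true;
}

int main() {
  assert(ShouldFuseProducer(Opcode::kAdd));
  assert(!ShouldFuseProducer(Opcode::kFusion));
  assert(!ShouldFuseProducer(Opcode::kRng));
  return 0;
}
```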
if (IsReductionToVector(*producer)) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 21f3b542a27..bb2990e6dfc 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -25,7 +25,7 @@ namespace gpu { class GpuInstructionFusion : public InstructionFusion { public: explicit GpuInstructionFusion(bool may_duplicate) - : InstructionFusion(may_duplicate) {} + : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index c58af04bad0..896f6ea8425 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { @@ -32,7 +31,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -49,7 +48,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -89,7 +88,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -108,7 +107,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) { HloInstruction::CreateGetTupleElement(data_shape, param, 1)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -124,3 +123,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) { } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index e141179ba17..a77d3d7065c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -59,6 +60,11 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -85,15 +91,19 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { } bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // Forward convolution. if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = hlo.convolution_dimension_numbers(); - // Only 2D convolutions are implemented. - // TODO(b/32873825): add support for 3D convolutions using CuDNN. - if (dnums.spatial_dimensions_size() != 2) { + if (dnums.spatial_dimensions_size() > 3) { return false; } + // CuDNN does not accept zero-element arguments if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 4d3e9b10b2e..e8c68a6ef72 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -25,16 +25,7 @@ limitations under the License. namespace xla { namespace gpu { -const int64 kWarpSize = 32; - -// Precondition: "hlo" is an operand of a Dot instruction. -// -// Returns whether "hlo" is foldable to its user. -bool IsOperandFoldableToDot(const HloInstruction& hlo); - -// Returns true if GpuCompiler can fold any operands of "dot" into "dot" for -// better performance. -bool CanFoldOperandsIntoDot(const HloInstruction& dot); +constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. 
bool ImplementedAsGemm(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index cad2c903ff3..7d5b6ed5cfa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -54,11 +54,12 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, : ir_emitter_context_(ir_emitter_context), ir_builder_(ir_emitter_context->llvm_module()->getContext()), bindings_(ir_emitter_context->hlo_module(), - &ir_emitter_context->buffer_assignment(), - &ir_emitter_context->temp_buffer_offsets(), &ir_builder_, + &ir_emitter_context->buffer_assignment(), &ir_builder_, is_nested), hlo_module_config_(hlo_module_config) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(hlo_module_config)); + ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + /*fast_math_enabled=*/hlo_module_config.debug_options() + .xla_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -99,7 +100,7 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { // sometimes, e.g., when it's operand is a constant or a bitcast of a // constant. if (bindings_.BoundToIrValue(*operand)) { - bindings_.BindHloToIrValue(*bitcast, bindings_.GetBasePointer(*operand)); + bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand)); } return Status::OK(); } @@ -400,7 +401,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, llvm::Type* accum_type = target_array.GetElementLlvmType(); llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. - "accum_address", // The name of the alloca instuction. + "accum_address", // The name of the alloca instruction. &ir_builder_); // Initialize the accumulator in the preheader to zero. @@ -431,12 +432,12 @@ Status IrEmitter::HandleDot(HloInstruction* dot, // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. 
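The `int` to `size_t` loop-index changes in ir_emitter.cc below silence signed/unsigned comparisons against `.size()`. In miniature:

```cpp
// Indexing with size_t matches the unsigned type returned by .size() and
// avoids -Wsign-compare warnings.
#include <cstddef>
#include <vector>

int main() {
  std::vector<int> index = {1, 2, 3};
  int sum = 0;
  // Before: for (int i = 0; i < index.size(); ++i)  -- signed vs. unsigned.
  for (size_t i = 0; i < index.size(); ++i) {
    sum += index[i];
  }
  return sum == 6 ? 0 : 1;
}
```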
llvm_ir::IrArray::Index target_index; - for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { + for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); } } - for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { + for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } @@ -514,7 +515,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (int64 i = 0; i < input_index.size(); ++i) { + for (size_t i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -550,14 +551,12 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); } -Status IrEmitter::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { +Status IrEmitter::HandleCall(HloInstruction* call) { std::vector operand_addresses; - for (HloInstruction* operand : operands) { + for (HloInstruction* operand : call->operands()) { operand_addresses.push_back(GetBasePointer(*operand)); } - return EmitCallToNestedComputation(*computation, operand_addresses, + return EmitCallToNestedComputation(*call->to_apply(), operand_addresses, GetBasePointer(*call)); } @@ -615,7 +614,7 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( llvm_ir::IrArray::Index index = loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); // Verify every dimension except the reduction dimension was set in the index. - for (int dimension = 0; dimension < index.size(); ++dimension) { + for (size_t dimension = 0; dimension < index.size(); ++dimension) { if (dimension == reduction_dimension) { DCHECK_EQ(nullptr, index[dimension]); } else { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index c8ca5c41b08..607a366ac67 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -101,9 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* on_true, HloInstruction* on_false) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) override; + Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) override; @@ -127,12 +125,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling BufferAssignment::GetAllocationIndex. - BufferAllocation::Index GetAllocationIndex(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice. 
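HandleCall above now takes only the instruction and derives the operands and the called computation from accessors, instead of receiving them as extra parameters. A toy sketch of that signature simplification; the types are hypothetical stand-ins, not the HLO classes:

```cpp
// After the change the handler has one parameter, so callers cannot pass a
// mismatched operand list or the wrong computation.
#include <cstdio>
#include <string>
#include <vector>

struct Computation {
  std::string name;
};

struct Instruction {
  std::vector<const Instruction*> operands_;
  Computation* to_apply_;
  const std::vector<const Instruction*>& operands() const { return operands_; }
  Computation* to_apply() const { return to_apply_; }
};

void HandleCall(const Instruction* call) {
  std::printf("call with %zu operands -> %s\n", call->operands().size(),
              call->to_apply()->name.c_str());
}

int main() {
  Computation comp{"nested"};
  Instruction a{}, b{};
  Instruction call{{&a, &b}, &comp};
  HandleCall(&call);
  return 0;
}
```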
+ BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelAllocation(&hlo) - .ConsumeValueOrDie() - ->index(); + .GetUniqueTopLevelSlice(&hlo) + .ConsumeValueOrDie(); } // Emit a singlethreaded or multithreaded loop that computes every element in @@ -250,8 +247,8 @@ class IrEmitterUnnested : public IrEmitter { Status HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice operands) override; - Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random, RandomDistribution distribution) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, @@ -345,6 +342,10 @@ class IrEmitterUnnested : public IrEmitter { // Returns a CopyThunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildCopyThunk(const HloInstruction* inst); + // Returns an InfeedThunk that performs device-to-device memcpy to implement + // `inst`. + std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index b204d9625c1..454c3f9ab2d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -19,7 +19,6 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -33,12 +32,10 @@ class IrEmitterContext { public: IrEmitterContext(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, - const TempBufferOffsets* temp_buffer_offsets, const perftools::gputools::DeviceDescription* device_desc, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), - temp_buffer_offsets_(temp_buffer_offsets), device_desc_(device_desc), llvm_module_(llvm_module) {} // Disallow copy and assign. 
@@ -50,9 +47,6 @@ class IrEmitterContext { const BufferAssignment& buffer_assignment() const { return *buffer_assignment_; } - const TempBufferOffsets& temp_buffer_offsets() const { - return *temp_buffer_offsets_; - } const perftools::gputools::DeviceDescription& device_description() const { return *device_desc_; } @@ -62,7 +56,6 @@ class IrEmitterContext { private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; - const TempBufferOffsets* temp_buffer_offsets_; const perftools::gputools::DeviceDescription* device_desc_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c107f9cbbe2..5fa2bfdd7e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -190,14 +191,15 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // The last argument is a pointer to the temporary buffer memory block. // We know that it doesn't alias any of the escaped arguments (the inputs + // the result). We also know how many bytes can be dereferenced in it. - const llvm::Argument& temp_buffer = kernel->getArgumentList().back(); - int64 temp_buffer_size = - ir_emitter_context_->temp_buffer_offsets().TotalSizeInBytes(); + const llvm::Argument& temp_buffer = *std::prev(kernel->arg_end()); int64 temp_buffer_arg_no = temp_buffer.getArgNo(); - if (temp_buffer_size > 0) { - kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, temp_buffer_size); + int64 temp_allocation_total_size = + ir_emitter_context_->buffer_assignment().temp_allocation_total_size(); + if (temp_allocation_total_size != 0) { + kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, + temp_allocation_total_size); } - kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. @@ -249,6 +251,46 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution, rhs_instruction, window); } +namespace { + +// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. +// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the +// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. +const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, + ShapeIndex* index) { + if (hlo->opcode() == HloOpcode::kGetTupleElement) { + const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); + index->push_back(hlo->tuple_index()); + return operand; + } + return hlo; +} + +// Checks if we can emit code for DynamicUpdateSlice to update data in-place. +// Returns true if operand 0 of DynamicUpdateSlice and its output buffer +// share the same buffer allocation. +// Returns false otherwise. 
+bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, + HloInstruction* fusion) { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + HloInstruction* fused_root = fusion->fused_expression_root(); + if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice) { + return false; + } + // Walk DynamicUpdateSlice operand(0) to fused parameter and get its + // associated operand. See if it shares an allocation with this operand. + ShapeIndex index; + auto* fusion_operand = + LatestNonGteAncestorAndIndex(fused_root->operand(0), &index); + if (fusion_operand->opcode() != HloOpcode::kParameter) { + return false; + } + auto* operand = fusion->operand(fusion_operand->parameter_number()); + return assignment.SharesSliceAtIndex(fusion, {}, operand, index); +} + +} // namespace + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -277,20 +319,19 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); Shape input_shape = root->operand(0)->shape(); - // EmitRedutionToVector requires the input shape to have a layout, but + // EmitReductionToVector requires the input shape to have a layout, but // fused instructions don't have one. So we determine its layout from // the fusion's operands. The choice of the layout only affects // performance but not correctness. auto choose_input_layout = []( tensorflow::gtl::ArraySlice operands, - Shape* input_shape) { + Shape* input_shape) -> Status { // Prefer the layout of an operand whose shape is compatible with // input_shape. for (const HloInstruction* operand : operands) { if (ShapeUtil::Compatible(*input_shape, operand->shape())) { - LayoutUtil::CopyLayoutBetweenShapes(operand->shape(), - input_shape); - return; + return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(), + input_shape); } } // If no operand has a compatible shape, prefer an operand that has @@ -301,24 +342,114 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Do not use CopyLayoutBetweenShapes because input_shape and // operand->shape() may be incompatible. *input_shape->mutable_layout() = operand->shape().layout(); - return; + return Status::OK(); } } // When all the above fails, which is rare, set the default layout. LayoutUtil::SetToDefaultLayout(input_shape); + return Status::OK(); }; - choose_input_layout(fusion->operands(), &input_shape); + TF_RETURN_IF_ERROR( + choose_input_layout(fusion->operands(), &input_shape)); return EmitReductionToVector( root, input_shape, fused_emitter.GetGenerator(root->operand(0)), fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), root->to_apply()); - break; } default: LOG(FATAL) << "Bad opcode for input fusion: " << fusion->fused_expression_root()->opcode(); } + } else if (HloInstruction::FusionKind::kLoop == fusion->fusion_kind() && + root->opcode() == HloOpcode::kDynamicUpdateSlice && + CanUpdateDynamicSliceInPlace( + ir_emitter_context_->buffer_assignment(), fusion)) { + // Loop fusion instruction with DynamicUpdateSlice as fused root. + // DynamicUpdateSlice's operand(0) and 'fusion' output share the same + // BufferAllocation::Slice, so it is safe to emit code to update the slice + // 'in-place'. This avoids copying data outside of the slice update region. + + // Set up kernel thunk and fused ir emitter. 
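The essence of CanUpdateDynamicSliceInPlace above is an aliasing test: the fusion's output slice must be the very slice assigned to the operand being updated, otherwise writing in place would clobber live data. Reduced to toy types, with plain ints standing in for BufferAllocation::Slice:

```cpp
// Sketch of the in-place condition checked via SharesSliceAtIndex.
#include <cassert>

struct Hlo {
  int slice;  // buffer slice assigned to this instruction's value
};

bool CanUpdateInPlace(const Hlo& fusion, const Hlo& operand) {
  // The fusion output must alias the operand that DynamicUpdateSlice updates.
  return fusion.slice == operand.slice;
}

int main() {
  Hlo fusion{7};
  Hlo aliased{7};
  Hlo fresh{9};
  assert(CanUpdateInPlace(fusion, aliased));
  assert(!CanUpdateInPlace(fusion, fresh));
  return 0;
}
```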
+ thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); + std::vector parameter_arrays; + for (HloInstruction* operand : fusion->operands()) { + parameter_arrays.push_back(GetIrArray(*operand)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &ir_builder_, GetNestedComputer()); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); + + // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0. + auto* fusion_operand = LatestNonGteAncestor(root->operand(0)); + CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode()); + + // Operand(0) the input array which shares an allocation with the output. + const auto* input = root->operand(0); + llvm::Value* input_base_ptr = fused_emitter.GetIrValueForGTE(input); + // Operand(1) 'update' is slice with which to update input at operand(0). + const auto* update = root->operand(1); + Shape update_shape = update->shape(); + TF_RETURN_IF_ERROR( + LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); + // Operand(2) the dynamic slice indices at which to write 'update'. + const auto* start_indices = root->operand(2); + + // Create element generators for 'update' and 'start_indices'. + llvm_ir::ElementGenerator element_generator = + fused_emitter.GetGenerator(update); + llvm_ir::ElementGenerator start_generator = + fused_emitter.GetGenerator(start_indices); + + // Create loop body emitter which emits code to do the following: + // *) Read dynamic slice start indices into 'start_index'. + // *) Map requested 'index' and slice 'start_index' to input/output shape + // as 'output_index'. + // *) Reads value from 'update' element generator. + // *) Writes value to input/output array at 'output_index'. + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& index) -> Status { + // Emit IR to read dynamic start indices from hlo->operand(2). + const int64 rank = ShapeUtil::Rank(input->shape()); + llvm_ir::IrArray::Index start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); + TF_ASSIGN_OR_RETURN(start_index[i], start_generator(dim_index)); + } + + // Calculate 'output_index' at which to write value from update. + llvm_ir::IrArray::Index output_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // output_index = (start_index + index) % dim_size + llvm::Value* dim_size = llvm::ConstantInt::get( + index[i]->getType(), input->shape().dimensions(i)); + llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( + start_index[i], index[i]->getType()); + output_index[i] = ir_builder_.CreateURem( + ir_builder_.CreateAdd(start_index0, index[i]), dim_size); + } + + // Read value from 'update'. + TF_ASSIGN_OR_RETURN(llvm::Value * input_value, element_generator(index)); + // Write value to output array. + llvm_ir::IrArray(input_base_ptr, input->shape()) + .EmitWriteArrayElement(output_index, input_value, &ir_builder_); + return Status::OK(); + }; + + // Create loop which iterates over 'update' shape. 
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + update_shape, ir_emitter_context_->device_description()); + CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); + UpdateLaunchDimensions(launch_dimensions, + static_cast(LastThunk()), + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, update_shape, + launch_dimensions, &ir_builder_) + .EmitLoop(); } if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); @@ -1195,12 +1326,12 @@ Status IrEmitterUnnested::HandleTuple( // buffer -- their contents are stored in code. In that case, we fall back // to emitting kernels which have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { - std::vector tuple_element_buffers; + std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : operands) { - tuple_element_buffers.push_back(GetAllocationIndex(*tuple_element)); + tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } thunk_sequence_->emplace_back(MakeUnique( - tuple_element_buffers, GetAllocationIndex(*tuple), tuple)); + tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } // If `inst` is a nested thunk that can be disassembled from the result tuple, @@ -1412,10 +1543,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( .EmitLoop(); } -Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while, - HloInstruction* init, - HloComputation* condition, - HloComputation* body) { +Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { + HloComputation* condition = xla_while->while_condition(); TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; @@ -1451,6 +1580,11 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select, return IrEmitter::HandleSelect(select, pred, on_true, on_false); } +Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { + thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); + return Status::OK(); +} + llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( const HloInstruction& hlo, std::vector* io_hlos) { const BufferAssignment& buffer_assignment = @@ -1463,8 +1597,9 @@ llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( for (const HloInstruction* operand : hlo.operands()) { const HloInstruction* to_lookup = LatestNonGteAncestor(operand); if (buffer_assignment.HasTopLevelAllocation(to_lookup) && - buffer_assignment.GetUniqueTopLevelAllocation(to_lookup) + buffer_assignment.GetUniqueTopLevelSlice(to_lookup) .ConsumeValueOrDie() + .allocation() ->IsInputOrOutput()) { io_hlos->push_back(operand); } else { @@ -1474,8 +1609,9 @@ llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode()); if (buffer_assignment.HasTopLevelAllocation(&hlo) && - buffer_assignment.GetUniqueTopLevelAllocation(&hlo) + buffer_assignment.GetUniqueTopLevelSlice(&hlo) .ConsumeValueOrDie() + .allocation() ->IsInputOrOutput()) { io_hlos->push_back(&hlo); } else { @@ -1496,9 +1632,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( EmitBasePointersForHloAndItsOperands(*inst, &io_hlos); // Compute the input buffer indices. 
- std::vector io_buffers; + std::vector io_buffers; + io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationIndex(*LatestNonGteAncestor(io_hlo))); + io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); } // Create a KernelThunk that launches the kernel that implements "inst". @@ -1512,10 +1649,21 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( /*source_address=*/LiteralUtil::InternalData(operand->literal()), - /*destination_buffer=*/GetAllocationIndex(*inst), - /*mem_size=*/llvm_ir::ByteSizeOf( - operand->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ + llvm_ir::ByteSizeOf(operand->shape(), + ir_emitter_context_->llvm_module()->getDataLayout()), + inst); +} + +std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); + return MakeUnique( + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ + llvm_ir::ByteSizeOf(inst->shape(), + ir_emitter_context_->llvm_module()->getDataLayout()), inst); } @@ -1525,9 +1673,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); return MakeUnique( - GetAllocationIndex(*lhs), // The buffer assigned to LHS. - GetAllocationIndex(*rhs), // The buffer assigned to RHS. - GetAllocationIndex(*inst), // The output buffer. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. @@ -1549,9 +1697,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->operand(rhs_parameter->parameter_number()); return MakeUnique( - GetAllocationIndex(*lhs), // The buffer assigned to LHS. - GetAllocationIndex(*rhs), // The buffer assigned to RHS. - GetAllocationIndex(*inst), // The output buffer. + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. @@ -1571,9 +1719,9 @@ std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( // Forward covolution. 
return MakeUnique( ConvolutionThunk::ConvolutionKind::kForward, - /*input_buffer=*/GetAllocationIndex(*lhs), - /*filter_buffer=*/GetAllocationIndex(*rhs), - /*output_buffer=*/GetAllocationIndex(*inst), + /*input_buffer=*/GetAllocationSlice(*lhs), + /*filter_buffer=*/GetAllocationSlice(*rhs), + /*output_buffer=*/GetAllocationSlice(*inst), /*input_shape=*/lhs->shape(), /*filter_shape=*/rhs->shape(), /*output_shape=*/inst->shape(), inst->window(), @@ -1587,9 +1735,9 @@ std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( case HloInstruction::FusionKind::kConvBackwardFilter: return MakeUnique( ConvolutionThunk::ConvolutionKind::kBackwardFilter, - /*input_buffer=*/GetAllocationIndex(*lhs), - /*filter_buffer=*/GetAllocationIndex(*inst), - /*output_buffer=*/GetAllocationIndex(*rhs), + /*input_buffer=*/GetAllocationSlice(*lhs), + /*filter_buffer=*/GetAllocationSlice(*inst), + /*output_buffer=*/GetAllocationSlice(*rhs), /*input_shape=*/lhs->shape(), /*filter_shape=*/inst->shape(), /*output_shape=*/rhs->shape(), inst->window(), @@ -1597,9 +1745,9 @@ std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( case HloInstruction::FusionKind::kConvBackwardInput: return MakeUnique( ConvolutionThunk::ConvolutionKind::kBackwardInput, - /*input_buffer=*/GetAllocationIndex(*inst), - /*filter_buffer=*/GetAllocationIndex(*rhs), - /*output_buffer=*/GetAllocationIndex(*lhs), + /*input_buffer=*/GetAllocationSlice(*inst), + /*filter_buffer=*/GetAllocationSlice(*rhs), + /*output_buffer=*/GetAllocationSlice(*lhs), /*input_shape=*/inst->shape(), /*filter_shape=*/rhs->shape(), /*output_shape=*/lhs->shape(), inst->window(), @@ -1651,26 +1799,23 @@ namespace { Status CheckWhileBuffersShareAllocation( const HloInstruction* xla_while, const BufferAssignment& buffer_assignment) { - return ShapeUtil::ForEachSubshape( + return ShapeUtil::ForEachSubshapeWithStatus( xla_while->shape(), [&buffer_assignment, &xla_while](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { auto check = [&buffer_assignment](const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index) -> Status { - BufferAllocation::Index index_a = - buffer_assignment.GetUniqueAllocation(a, index) - .ConsumeValueOrDie() - ->index(); - BufferAllocation::Index index_b = - buffer_assignment.GetUniqueAllocation(b, index) - .ConsumeValueOrDie() - ->index(); - if (index_a != index_b) { + const BufferAllocation::Slice slice_a = + buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie(); + const BufferAllocation::Slice slice_b = + buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie(); + if (slice_a != slice_b) { return InternalError( - "instruction %s does not share allocation with " - "instruction %s ", - a->ToString().c_str(), b->ToString().c_str()); + "instruction %s %s does not share allocation with " + "instruction %s %s", + a->ToString().c_str(), slice_a.ToString().c_str(), + b->ToString().c_str(), slice_b.ToString().c_str()); } return Status::OK(); }; @@ -1710,7 +1855,7 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); return MakeUnique( - GetAllocationIndex(*condition->root_instruction()), // cond result + GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 14760fe92cc..69399e36c4c 100644 --- 
a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -28,11 +28,9 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -using Index = BufferAllocation::Index; - -KernelThunk::KernelThunk(tensorflow::gtl::ArraySlice io_buffers, - const string& kernel_name, - const HloInstruction* hlo_instruction) +KernelThunk::KernelThunk( + tensorflow::gtl::ArraySlice io_buffers, + const string& kernel_name, const HloInstruction* hlo_instruction) : Thunk(Kind::kKernel, hlo_instruction), io_buffers_(io_buffers.begin(), io_buffers.end()), kernel_name_(kernel_name) {} @@ -62,20 +60,25 @@ tensorflow::Status KernelThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); - se::KernelBase kernel(executor); LaunchDimensions launch_dimensions; + const se::KernelBase* kernel = nullptr; { tensorflow::mutex_lock lock(mutex_); - if (!executor->GetKernel(*loader_spec_, &kernel)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } launch_dimensions = launch_dimensions_; + kernel = &it->second; } // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; auto kernel_args = MakeUnique>(); - for (BufferAllocation::Index io_buffer : io_buffers_) { + for (const BufferAllocation::Slice io_buffer : io_buffers_) { kernel_args->add_device_memory_argument( buffer_allocations.GetDeviceAddress(io_buffer)); } @@ -83,7 +86,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( buffer_allocations.GetTempBufferBase()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), - se::BlockDim(launch_dimensions.block_count()), kernel, + se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 901825873ae..350b5aaf360 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -46,7 +46,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + KernelThunk(tensorflow::gtl::ArraySlice io_buffers, const string& kernel_name, const HloInstruction* hlo_instruction); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; @@ -64,7 +64,7 @@ class KernelThunk : public Thunk { private: // The indices of the input/output buffers. - const std::vector io_buffers_; + const std::vector io_buffers_; // Entry kernel name for the computation. 
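The kernel_cache_ change above loads each kernel at most once per StreamExecutor, with both the lookup and the insertion performed under the mutex. A standalone sketch of the pattern with toy Executor and Kernel types standing in for the StreamExecutor API:

```cpp
// Per-executor cache: look up under a mutex, insert on miss, and hand back
// a stable pointer to the cached entry on every subsequent call.
#include <cassert>
#include <mutex>
#include <string>
#include <unordered_map>

struct Executor {};
struct Kernel {
  std::string name;
};

class KernelCache {
 public:
  const Kernel* GetOrLoad(Executor* executor, const std::string& name) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = cache_.find(executor);
    if (it == cache_.end()) {
      // Miss: "load" the kernel once and remember it for this executor.
      it = cache_.emplace(executor, Kernel{name}).first;
    }
    return &it->second;
  }

 private:
  std::mutex mu_;
  std::unordered_map<Executor*, Kernel> cache_;
};

int main() {
  KernelCache cache;
  Executor exec;
  const Kernel* first = cache.GetOrLoad(&exec, "fusion_kernel");
  const Kernel* second = cache.GetOrLoad(&exec, "fusion_kernel");
  assert(first == second);  // same cached instance, loaded only once
  return 0;
}
```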
const string kernel_name_; @@ -78,6 +78,11 @@ class KernelThunk : public Thunk { mutable tensorflow::mutex mutex_; std::unique_ptr loader_spec_ GUARDED_BY(mutex_); + + // Loaded kernels for each `StreamExecutor` + std::unordered_map + kernel_cache_ GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc index ff6cfd94484..66cc7b3e40d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -79,26 +79,37 @@ Status GpuLayoutAssignment::AddBackendConstraints( // calls after we switch to cuDNN v5. const ConvolutionDimensionNumbers& dimension_numbers = instruction->convolution_dimension_numbers(); + std::vector input_layout; + for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; + --i) { + input_layout.push_back(dimension_numbers.spatial_dimensions(i)); + } + input_layout.push_back(dimension_numbers.feature_dimension()); + input_layout.push_back(dimension_numbers.batch_dimension()); Shape input_shape(input->shape()); - *input_shape.mutable_layout() = - LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1), - dimension_numbers.spatial_dimensions(0), - dimension_numbers.feature_dimension(), - dimension_numbers.batch_dimension()}); + *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); + std::vector filter_layout; + for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; + i >= 0; --i) { + filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); + } + filter_layout.push_back( + dimension_numbers.kernel_input_feature_dimension()); + filter_layout.push_back( + dimension_numbers.kernel_output_feature_dimension()); Shape filter_shape(filter->shape()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout( - {dimension_numbers.kernel_spatial_dimensions(1), - dimension_numbers.kernel_spatial_dimensions(0), - dimension_numbers.kernel_input_feature_dimension(), - dimension_numbers.kernel_output_feature_dimension()}); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); + std::vector output_layout; + for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; + --i) { + output_layout.push_back(dimension_numbers.spatial_dimensions(i)); + } + output_layout.push_back(dimension_numbers.feature_dimension()); + output_layout.push_back(dimension_numbers.batch_dimension()); Shape output_shape(output->shape()); - *output_shape.mutable_layout() = - LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1), - dimension_numbers.spatial_dimensions(0), - dimension_numbers.feature_dimension(), - dimension_numbers.batch_dimension()}); + *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // Set layouts of the instructions' shapes. 
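The layout-assignment rewrite above generalizes the hard-coded 2D case to any number of spatial dimensions: the layout is built minor-to-major as the spatial dimensions in reverse, then the feature dimension, then the batch dimension. A sketch of that construction, with `DimensionNumbers` as a simplified stand-in for the proto accessors:

```c++
#include <cstdint>
#include <vector>

// Stand-in for the convolution dimension-number proto accessors.
struct DimensionNumbers {
  std::vector<int64_t> spatial_dimensions;
  int64_t feature_dimension;
  int64_t batch_dimension;
};

// Builds the minor-to-major layout used above: spatial dimensions reversed
// (innermost spatial dimension minor-most), then feature, then batch.
std::vector<int64_t> MakeMinorToMajorLayout(const DimensionNumbers& dnums) {
  std::vector<int64_t> layout;
  for (auto it = dnums.spatial_dimensions.rbegin();
       it != dnums.spatial_dimensions.rend(); ++it) {
    layout.push_back(*it);
  }
  layout.push_back(dnums.feature_dimension);
  layout.push_back(dnums.batch_dimension);
  return layout;
}
```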
if (instruction->opcode() == HloOpcode::kConvolution) { diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc index 692ec8147d3..fa258b6e567 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc @@ -55,9 +55,9 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* computation = - module.AddEntryComputation(builder.Build(add)); + module->AddEntryComputation(builder.Build(add)); ComputationLayout computation_layout( computation->ComputeProgramShape()); @@ -69,7 +69,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment(&computation_layout); - EXPECT_TRUE(layout_assignment.Run(&module).ValueOrDie()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), @@ -83,3 +83,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { } // namespace } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 12ea573a9c1..e03571a9672 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "external/llvm/include/llvm/ADT/STLExtras.h" @@ -53,6 +54,7 @@ limitations under the License. #include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h" #include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h" +#include "external/llvm/include/llvm/Transforms/IPO/Internalize.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" @@ -68,29 +70,64 @@ namespace { // Default inline threshold value to use in llvm. const int kDefaultInlineThreshold = 1100; -// Information about a GPU architecture for the backend. -struct GpuBackendInfo { - string libdevice_name; - string sm_name; -}; - -// Maps supported CUDA compute capability to a libdevice file to link for this -// capability. -std::map gpu_info_map = { - {"compute_20", {"libdevice.compute_20.10.bc", "sm_20"}}, - {"compute_30", {"libdevice.compute_30.10.bc", "sm_30"}}, - {"compute_35", {"libdevice.compute_35.10.bc", "sm_35"}}, - - // NVIDIA does not provide a separate libdevice for CC 3.7, but we can use - // the one for 3.5. - {"compute_37", {"libdevice.compute_35.10.bc", "sm_37"}}, -}; - -// Validate the --gpu_architecture command-line flag. 
-static void ValidateGPUArchitecture(const string& value) { - if (!gpu_info_map.count(value)) { - LOG(FATAL) << "value for --gpu_architecture must be compute_{20,30,35,37}"; +// Gets the libdevice filename for a particular compute capability. When +// presented with a GPU we don't recognize, we just return the libdevice from +// compute_20. +static string GetLibdeviceFilename(std::pair<int, int> compute_capability) { + // There are only four libdevice files: compute_{20,30,35,50}. Each GPU + // version gets mapped to one of these. Note in particular that sm_60 and + // sm_61 map to libdevice.compute_30. + static auto* m = new std::map<std::pair<int, int>, int>({{{2, 0}, 20}, + {{2, 1}, 20}, + {{3, 0}, 30}, + {{3, 2}, 30}, + {{3, 5}, 35}, + {{3, 7}, 35}, + {{5, 0}, 50}, + {{5, 2}, 50}, + {{5, 3}, 50}, + {{6, 0}, 30}, + {{6, 1}, 30}, + {{6, 2}, 30}}); + int libdevice_version = 20; + auto it = m->find(compute_capability); + if (it != m->end()) { + libdevice_version = it->second; + } else { + LOG(WARNING) << "Unknown compute capability (" << compute_capability.first + << ", " << compute_capability.second << "). " + << "Defaulting to libdevice for compute_" << libdevice_version; + } + return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, + ".10.bc"); +} + +// Gets the GPU name as it's known to LLVM for a given compute capability. If +// we see an unrecognized compute capability, we return "sm_20". +static string GetSmName(std::pair<int, int> compute_capability) { + static auto* m = new std::map<std::pair<int, int>, int>({{{2, 0}, 20}, + {{2, 1}, 21}, + {{3, 0}, 30}, + {{3, 2}, 32}, + {{3, 5}, 35}, + {{3, 7}, 37}, + {{5, 0}, 50}, + {{5, 2}, 52}, + {{5, 3}, 53}, + {{6, 0}, 60}, + {{6, 1}, 61}, + {{6, 2}, 62}}); + int sm_version = 20; + auto it = m->find(compute_capability); + if (it != m->end()) { + sm_version = it->second; + } else { + LOG(WARNING) << "Unknown compute capability (" << compute_capability.first + << ", " << compute_capability.second << "). " + << "Defaulting to telling LLVM that we're compiling for sm_" + << sm_version; + } + return tensorflow::strings::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product @@ -135,8 +172,10 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - // Set options from hlo_module_config (specifically, fast-math flags). - llvm_ir::SetTargetOptions(hlo_module_config, &target_options); + llvm_ir::SetTargetOptions( + /*fast_math_enabled=*/hlo_module_config.debug_options() + .xla_enable_fast_math(), + &target_options); // Enable FMA synthesis if desired. legacy_flags::GpuBackendLibFlags* flags = @@ -270,39 +309,41 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. 
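A condensed, self-contained sketch of the new lookup-with-fallback behavior: unknown compute capabilities now warn and fall back to compute_20 instead of aborting, as the deleted `--gpu_architecture` validation did. The table here is abbreviated; file names follow the `libdevice.compute_NN.10.bc` convention used above:

```c++
#include <map>
#include <string>
#include <utility>

// Maps a (major, minor) compute capability to a libdevice version; unknown
// GPUs fall back to compute_20 rather than aborting.
std::string LibdeviceFor(std::pair<int, int> cc) {
  static const std::map<std::pair<int, int>, int> kVersions = {
      {{2, 0}, 20}, {{3, 0}, 30}, {{3, 5}, 35}, {{5, 0}, 50}};
  auto it = kVersions.find(cc);
  const int version = it != kVersions.end() ? it->second : 20;  // fallback
  return "libdevice.compute_" + std::to_string(version) + ".10.bc";
}
// e.g. LibdeviceFor({6, 1}) == "libdevice.compute_20.10.bc" with this
// abbreviated table; the full table above maps {6, 1} to compute_30.
```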
-tensorflow::Status LinkLibdeviceIfNecessary(const string& libdevice_dir_path, - llvm::Module* module) { +tensorflow::Status LinkLibdeviceIfNecessary( + llvm::Module* module, std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { return tensorflow::Status::OK(); } llvm::Linker linker(*module); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - ValidateGPUArchitecture(flags->gpu_architecture); - string libdevice_bc_filename = - gpu_info_map[flags->gpu_architecture].libdevice_name; - string libdevice_bc_fullpath = - tensorflow::io::JoinPath(libdevice_dir_path, libdevice_bc_filename); - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->FileExists(libdevice_bc_fullpath)); + string libdevice_path = tensorflow::io::JoinPath( + libdevice_dir_path, GetLibdeviceFilename(compute_capability)); + TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(libdevice_path)); + VLOG(1) << "Linking with libdevice from: " << libdevice_path; std::unique_ptr libdevice_module = - LoadIRModule(libdevice_bc_fullpath, &module->getContext()); - VLOG(1) << "Linking with libdevice from: " << libdevice_bc_fullpath; - if (linker.linkInModule(std::move(libdevice_module), - llvm::Linker::Flags::InternalizeLinkedSymbols | - llvm::Linker::Flags::LinkOnlyNeeded)) { - LOG(FATAL) << "Error linking libdevice from " << libdevice_bc_fullpath; + LoadIRModule(libdevice_path, &module->getContext()); + if (linker.linkInModule( + std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, + [](Module& M, const StringSet<>& GVS) { + internalizeModule(M, [&M, &GVS](const GlobalValue& GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); + }); + })) { + return tensorflow::errors::Internal(tensorflow::strings::StrCat( + "Error linking libdevice from ", libdevice_path)); } return tensorflow::Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, + std::pair compute_capability, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { // Link the input module with libdevice, to pull in implementations of some // builtins. - TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(libdevice_dir_path, module)); + TF_RETURN_IF_ERROR( + LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path)); legacy_flags::GpuBackendLibFlags* flags = legacy_flags::GetGpuBackendLibFlags(); @@ -351,17 +392,14 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // Figure out the exact name of the processor as known to the NVPTX backend // from the gpu_architecture flag. - ValidateGPUArchitecture(flags->gpu_architecture); - string cpu_name = gpu_info_map[flags->gpu_architecture].sm_name; - - std::unique_ptr target_machine = - GetTargetMachine(target_triple, cpu_name, hlo_module_config); + std::unique_ptr target_machine = GetTargetMachine( + target_triple, GetSmName(compute_capability), hlo_module_config); module_passes.add(llvm::createTargetTransformInfoWrapperPass( target_machine->getTargetIRAnalysis())); // The LLVM IR verifier performs sanity checking on the IR. This helps // discover problems and report them in a meaningful manner, rather than let - // later passes report obscure assertions becasue of unfulfilled invariants. + // later passes report obscure assertions because of unfulfilled invariants. module_passes.add(llvm::createVerifierPass()); // Create the function-level pass manager. 
It needs data layout information @@ -370,9 +408,9 @@ StatusOr CompileModuleToPtx(llvm::Module* module, AddOptimizationPasses(flags->opt_level, /*size_level=*/0, target_machine.get(), &module_passes, &function_passes); - // Loop unrolling exposes more opportunites for SROA. Therefore, we run SROA + // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. - // TODO(jingyue): SROA may further expose more optimization opportunites, such + // TODO(jingyue): SROA may further expose more optimization opportunities, such // as more precise alias analysis and more function inlining (SROA may change // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. We may want to run more @@ -466,6 +504,7 @@ void GPUBackendInit() { } // namespace StatusOr CompileToPtx(llvm::Module* module, + std::pair compute_capability, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { static std::once_flag backend_init_flag; @@ -477,7 +516,8 @@ StatusOr CompileToPtx(llvm::Module* module, "Compile module " + llvm_ir::AsString(module->getName()), /*vlog_level=*/2); TF_ASSIGN_OR_RETURN( - ptx, CompileModuleToPtx(module, hlo_module_config, libdevice_dir_path)); + ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, + libdevice_dir_path)); } return ptx; } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index cf6f3197bb7..fd894072170 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -18,6 +18,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ #include +#include #include "external/llvm/include/llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -28,14 +29,15 @@ limitations under the License. namespace xla { namespace gpu { -// The Compile.* interfaces each create their own llvm::LLVMContext objects for -// thread safety, but note that LLVM's multithreaded support is very -// preliminary; multithreaded use is not recommended at this time. -// // Compiles the argument module and returns it. libdevice_dir_path is the parent // directory of the libdevice bitcode libraries. The contents of the module may // be changed. +// +// The Compile.* interfaces each create their own llvm::LLVMContext objects for +// thread safety, but note that LLVM's multithreaded support is very +// preliminary; multithreaded use is not recommended at this time. StatusOr CompileToPtx(llvm::Module* module, + std::pair compute_capability, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index c10346bbc23..72f6cfd2d60 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -28,7 +28,8 @@ limitations under the License. 
namespace { static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) { - LOG(FATAL) << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() + LOG(FATAL) << diagnostic->getFilename().str() << ":" + << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() << ": " << diagnostic->getMessage().str(); } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index c645e84aa4f..a12a9a71682 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -80,6 +80,7 @@ HloInstruction* MaybePaddedAndSlicedInput( std::vector start_indices(input->shape().dimensions_size(), 0); std::vector limit_indices(input->shape().dimensions().begin(), input->shape().dimensions().end()); + std::vector strides(input->shape().dimensions_size(), 1); for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { int64 dim = conv_dnums.spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or @@ -92,9 +93,9 @@ HloInstruction* MaybePaddedAndSlicedInput( input = computation->AddInstruction(HloInstruction::CreateSlice( ShapeInference::InferSliceShape(input->shape(), start_indices, - limit_indices) + limit_indices, strides) .ConsumeValueOrDie(), - input, start_indices, limit_indices)); + input, start_indices, limit_indices, strides)); } return input; @@ -354,6 +355,8 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( std::vector limit_indices( new_backward_conv->shape().dimensions().begin(), new_backward_conv->shape().dimensions().end()); + std::vector strides(new_backward_conv->shape().dimensions_size(), + 1LL); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); @@ -373,13 +376,13 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( // Replace the old backward convolution with the slice. CHECK(ShapeUtil::Compatible( ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, - limit_indices) + limit_indices, strides) .ConsumeValueOrDie(), backward_conv->shape())); TF_CHECK_OK(computation->ReplaceWithNewInstruction( backward_conv, HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices))); + start_indices, limit_indices, strides))); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8ac4c599663..8f7fce884ac 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -33,7 +33,7 @@ namespace gpu { enum class PartitionStrategy { // Optimized for latency by allowing maximum number of registers per thread. kLatency, - // Optimized for throughtput. This may limit registers per thread and cause + // Optimized for throughput. This may limit registers per thread and cause // longer latency. 
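Back on the pad_insertion changes above: they thread an explicit unit-stride vector through `CreateSlice`/`InferSliceShape`, which now take strides. A strided slice dimension produces ceil((limit - start) / stride) elements, so unit strides leave the pass's output shapes unchanged. A tiny sketch (this `CeilOfRatio` is modeled on the XLA utility of the same name):

```c++
#include <cstdint>

// Ceiling division, modeled on the xla::CeilOfRatio utility.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Number of elements a slice [start, limit) with the given stride produces.
int64_t SliceDimSize(int64_t start, int64_t limit, int64_t stride) {
  return CeilOfRatio(limit - start, stride);
}
// e.g. SliceDimSize(0, 8, 1) == 8, while SliceDimSize(0, 8, 3) == 3,
// picking elements {0, 3, 6}.
```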
kThroughput }; diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 28d47d2b0f8..a5230b3e8e9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -45,10 +45,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(dot2)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(dot2)); - std::unique_ptr assignment = AssignStreams(module); + std::unique_ptr assignment = AssignStreams(*module); EXPECT_EQ(assignment->StreamNumberForHlo(*dot1), assignment->StreamNumberForHlo(*dot2)); } @@ -66,10 +66,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(add)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(add)); - std::unique_ptr assignment = AssignStreams(module); + std::unique_ptr assignment = AssignStreams(*module); EXPECT_NE(assignment->StreamNumberForHlo(*dot1), assignment->StreamNumberForHlo(*dot2)); } @@ -86,6 +86,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { // d40 -- layer 4 HloComputation::Builder builder("entry_computation"); std::vector params; + params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); @@ -109,10 +110,10 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build(d40)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build(d40)); - std::unique_ptr assignment = AssignStreams(module); + std::unique_ptr assignment = AssignStreams(*module); // The two dots on layer 1 are concurrent. EXPECT_NE(assignment->StreamNumberForHlo(*d10), assignment->StreamNumberForHlo(*d11)); @@ -130,3 +131,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc deleted file mode 100644 index 3cf5dd021a1..00000000000 --- a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h" - -#include "tensorflow/compiler/xla/map_util.h" - -namespace xla { -namespace gpu { - -namespace { -int64 RoundUpToAlignment(int64 value) { - // Any address of a variable residing in global memory or returned by one of - // the memory allocation routines from the driver or runtime API is always - // aligned to at least 256 bytes. - // (http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses) - static constexpr int64 kCudaMallocAlignment = 256; - return (value + kCudaMallocAlignment - 1) / kCudaMallocAlignment * - kCudaMallocAlignment; -} -} // namespace - -TempBufferOffsets::TempBufferOffsets( - const BufferAssignment& buffer_assignment) { - total_size_of_temp_buffers_ = 0; - for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); - if (allocation.IsPreallocatedTempBuffer()) { - InsertOrDie(&buffer_index_to_offset_, i, total_size_of_temp_buffers_); - total_size_of_temp_buffers_ += RoundUpToAlignment(allocation.size()); - } - } -} - -int64 TempBufferOffsets::GetOffset(BufferAllocation::Index index) const { - return FindOrDie(buffer_index_to_offset_, index); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h deleted file mode 100644 index 05aca99bf34..00000000000 --- a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ - -#include - -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace gpu { - -// GpuExecutable merges all temporary buffers into one memory block. This class -// stores the offset of each temporary buffer in that memory block. -class TempBufferOffsets { - public: - explicit TempBufferOffsets(const BufferAssignment& buffer_assignment); - - int64 GetOffset(BufferAllocation::Index index) const; - int64 TotalSizeInBytes() const { return total_size_of_temp_buffers_; } - - private: - std::map buffer_index_to_offset_; - - // The total size of all temporary buffers. This includes paddings that are - // necessary for alignment. 
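For context on the deleted helper: `TempBufferOffsets` packed all temp buffers into one block, rounding each offset up to the 256-byte alignment that CUDA guarantees for device allocations. A minimal sketch of that round-up invariant, taken directly from the removed code:

```c++
#include <cstdint>

// CUDA device allocations are aligned to at least 256 bytes, so offsets
// within the merged temp block were rounded up to that boundary.
constexpr int64_t kCudaMallocAlignment = 256;

int64_t RoundUpToAlignment(int64_t value) {
  return (value + kCudaMallocAlignment - 1) / kCudaMallocAlignment *
         kCudaMallocAlignment;
}
// e.g. RoundUpToAlignment(1) == 256, RoundUpToAlignment(256) == 256,
// RoundUpToAlignment(257) == 512.
```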
- int64 total_size_of_temp_buffers_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 3ced3484007..0ff27888ad7 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -44,6 +44,7 @@ class Thunk { kConvolution, kCopy, kGemm, + kInfeed, kKernel, kSequential, kTuple, diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 8addcd87eaa..bdb062837c5 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -79,7 +79,7 @@ ThunkSchedule::ThunkSchedule( void ThunkSchedule::RemoveRedundantDependencyEdges() { std::unordered_map thunk_to_total_order; - for (auto i = 0; i < thunk_total_order_.size(); ++i) { + for (int i = 0; i < thunk_total_order_.size(); ++i) { InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i); } diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 323775b3e84..bd65e72393a 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -25,7 +25,7 @@ namespace gpu { tensorflow::Status TupleThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Index tuple_element_buffer : tuple_element_buffers_) { + for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); } diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index ca0404286fb..3b1a4963285 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -33,9 +33,9 @@ namespace gpu { // issue (b/31336476). class TupleThunk : public Thunk { public: - TupleThunk(tensorflow::gtl::ArraySlice + TupleThunk(tensorflow::gtl::ArraySlice tuple_element_buffers, - BufferAllocation::Index dest_buffer, + const BufferAllocation::Slice& dest_buffer, const HloInstruction* hlo_instruction) : Thunk(Kind::kTuple, hlo_instruction), tuple_element_buffers_(tuple_element_buffers.begin(), @@ -50,8 +50,8 @@ class TupleThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - std::vector tuple_element_buffers_; - const BufferAllocation::Index dest_buffer_; + const std::vector tuple_element_buffers_; + const BufferAllocation::Slice dest_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 36883e4920a..0d2412096ab 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -22,10 +22,11 @@ limitations under the License. 
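The recurring Index-to-Slice migration in these thunks swaps "which allocation" for "which sub-range of an allocation", so several logical buffers can share one allocation yet still compare unequal. A conceptual stand-in, not the real `BufferAllocation::Slice` API:

```c++
#include <cstdint>

// Hypothetical stand-in: a Slice pins down a byte range inside a specific
// allocation.
struct Slice {
  int64_t allocation_index;  // which allocation
  int64_t offset;            // byte offset within the allocation
  int64_t size;              // extent in bytes
};

inline bool operator==(const Slice& a, const Slice& b) {
  return a.allocation_index == b.allocation_index && a.offset == b.offset &&
         a.size == b.size;
}
```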
namespace xla { namespace gpu { -WhileThunk::WhileThunk(BufferAllocation::Index condition_result_buffer_index, - std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence, - const HloInstruction* hlo) +WhileThunk::WhileThunk( + const BufferAllocation::Slice& condition_result_buffer_index, + std::unique_ptr condition_thunk_sequence, + std::unique_ptr body_thunk_sequence, + const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), condition_result_buffer_index_(condition_result_buffer_index), condition_thunk_sequence_(MakeUnique( diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 1658cdaf87f..95ed5497cea 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -38,7 +38,7 @@ namespace gpu { class WhileThunk : public Thunk { public: // Constructs a WhileThunk to compute while instruction 'hlo'. - WhileThunk(BufferAllocation::Index condition_result_buffer_index, + WhileThunk(const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, std::unique_ptr body_thunk_sequence, const HloInstruction* hlo); @@ -51,7 +51,7 @@ class WhileThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - BufferAllocation::Index condition_result_buffer_index_; + const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; }; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ec75e135814..06b01d311da 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -37,7 +37,7 @@ namespace { // patterns to match. // // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifing the index and HloOpcode +// of type ExprTree). Operands can be added by specifying the index and HloOpcode // of the operand. // // For example, the following computation: @@ -122,10 +122,12 @@ class ExprTree { Status Match(const HloInstruction* instruction, TaggedInstructionMap* tagged_instructions) const { if (opcode_ != instruction->opcode()) { - return InvalidArgument("Unexpected opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InvalidArgument("got opcode %s, want %s", + HloOpcodeString(instruction->opcode()).c_str(), + HloOpcodeString(opcode_).c_str()); } + VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; if (!tag_.empty()) { tagged_instructions->insert({tag_, instruction}); } @@ -166,7 +168,7 @@ class MatcherBase { virtual ~MatcherBase() {} // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first succesful match, error status otherwise. + // Returns OK on the first successful match, error status otherwise. 
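A minimal standalone sketch of the ExprTree matching idea described above: each node records an expected opcode, an optional tag, and sub-patterns keyed by operand index; `Match` walks the graph and fills a tag-to-instruction map. All types here are simplified stand-ins for the real `HloInstruction`/`ExprTree`:

```c++
#include <cstddef>
#include <map>
#include <string>
#include <vector>

enum class Opcode { kWhile, kTuple, kCopy, kConstant };

// Stand-in for HloInstruction: an opcode plus operands.
struct Node {
  Opcode opcode;
  std::vector<Node*> operands;
};

struct ExprTree {
  Opcode opcode;
  std::string tag;                 // empty means "don't record"
  std::vector<int> operand_index;  // which operand each child matches
  std::vector<ExprTree> children;  // sub-patterns, parallel to operand_index

  bool Match(const Node* n, std::map<std::string, const Node*>* tagged) const {
    if (n->opcode != opcode) return false;
    if (!tag.empty()) (*tagged)[tag] = n;
    for (size_t i = 0; i < children.size(); ++i) {
      const int idx = operand_index[i];
      if (idx >= static_cast<int>(n->operands.size()) ||
          !children[i].Match(n->operands[idx], tagged)) {
        return false;
      }
    }
    return true;
  }
};
```

The simplified init pattern in the diff below then reads as nested literals: kWhile wrapping kTuple wrapping kCopy wrapping a tagged kConstant.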
virtual tensorflow::Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { @@ -238,7 +240,7 @@ class MatcherBase { // class WhileConditionComputationMatcher : public MatcherBase { public: - WhileConditionComputationMatcher(const HloComputation* computation) + explicit WhileConditionComputationMatcher(const HloComputation* computation) : computation_(computation) { expr_trees_.emplace_back(BuildCondExprTree()); } @@ -275,6 +277,7 @@ class WhileConditionComputationMatcher : public MatcherBase { } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while condition"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), &tagged_instructions)); @@ -344,10 +347,6 @@ class WhileInitOperandMatcher : public MatcherBase { // // Const // | - // Tuple1 - // | - // GTE0 - // | // Copy // | // Tuple0 @@ -355,15 +354,15 @@ class WhileInitOperandMatcher : public MatcherBase { // While // ExprTree BuildInitExprTree() { - ExprTree gte0(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kConstant, "loop_start"))); - return ExprTree(HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, gte0))); + return ExprTree( + HloOpcode::kWhile, "while", + ExprTree(HloOpcode::kTuple, tuple_index_, + ExprTree(HloOpcode::kCopy, + ExprTree(HloOpcode::kConstant, "loop_start")))); } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while init"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); @@ -375,14 +374,6 @@ class WhileInitOperandMatcher : public MatcherBase { while_hlo->name().c_str()); } - // Get tagged GTE instruction and check 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - if (gte->tuple_index() != tuple_index_) { - return InvalidArgument("Unexpected tuple index instruction : %s", - gte->name().c_str()); - } - // Get tagged Constant instruction and parse 'loop_start_'. TF_ASSIGN_OR_RETURN( const HloInstruction* const_hlo, @@ -427,10 +418,6 @@ class WhileBodyComputationMatcher : public MatcherBase { // \ / \ / // Fusion -----------> Add // | - // Tuple1 - // | - // GTE0 - // | // Copy // | // Tuple0 @@ -450,15 +437,13 @@ class WhileBodyComputationMatcher : public MatcherBase { fusion.SetFusedRoot(fused_root); // Build top-level computation. - ExprTree tuple0( - HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kTuple, tuple_index_, fusion)))); + ExprTree tuple0(HloOpcode::kTuple, tuple_index_, + ExprTree(HloOpcode::kCopy, fusion)); return tuple0; } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while body"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), &tagged_instructions)); diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index ddf9676e378..e82491fd6f9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,16 +17,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { +using ::testing::Eq; +using ::testing::HasSubstr; + class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(TestName()), + : module_(CreateNewModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), loop_state_shape_(ShapeUtil::MakeTupleShape( @@ -98,26 +102,26 @@ class WhileTransformerTest : public HloTestBase { HloInstruction::CreateTuple({data_init, induction_var_init})); auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition, body, loop_state_init)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); return while_hlo; } void RunFusionPasses() { // Run standard fusion passes. EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(&module_) + .Run(module_.get()) .ValueOrDie()); EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(&module_) + .Run(module_.get()) .ValueOrDie()); } void RunCopyInsertionPass() { CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(&module_).status()); + EXPECT_IS_OK(copy_insertion.Run(module_.get()).status()); } - HloModule module_; + std::unique_ptr module_; Shape induction_variable_shape_; Shape data_shape_; Shape loop_state_shape_; @@ -127,74 +131,72 @@ class WhileTransformerTest : public HloTestBase { TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); - auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); + module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 0); // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_TRUE(result.ok()); + ASSERT_TRUE(result.ok()); // Check results. - auto tuple = result.ConsumeValueOrDie(); - EXPECT_EQ(0, std::get<0>(tuple)); - EXPECT_EQ(10, std::get<1>(tuple)); - EXPECT_EQ(1, std::get<2>(tuple)); + EXPECT_THAT(result.ConsumeValueOrDie(), + Eq(std::tuple(0, 10, 1))); } TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation(1, 10)); - auto body = module_.AddEmbeddedComputation(BuildBodyComputation(1, 0, 1)); + module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); + auto body = module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 1, 0); // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_TRUE(result.ok()); + ASSERT_TRUE(result.ok()); // Check results. 
- auto tuple = result.ConsumeValueOrDie(); - EXPECT_EQ(0, std::get<0>(tuple)); - EXPECT_EQ(10, std::get<1>(tuple)); - EXPECT_EQ(1, std::get<2>(tuple)); + EXPECT_THAT(result.ConsumeValueOrDie(), + Eq(std::tuple(0, 10, 1))); } TEST_F(WhileTransformerTest, InvalidLoopLimit) { // Build computation with invalid loop limit. auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation(0, 5)); - auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); + module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); + auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 10); // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_FALSE(result.ok()); - EXPECT_MATCH( - result.status().error_message(), - testing::ContainsRegex("Loop start must be less than loop limit.")); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + HasSubstr("Loop start must be less than loop limit.")); } TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation(0, 10)); - auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, -1)); + module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); + auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1)); auto while_hlo = BuildWhileInstruction(condition, body, 0, 0); // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_FALSE(result.ok()); - EXPECT_MATCH( - result.status().error_message(), - testing::ContainsRegex("Loop increment must greater than zero.")); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + HasSubstr("Loop increment must greater than zero.")); } } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc new file mode 100644 index 00000000000..4b8d190a463 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu_transfer_manager.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +// TODO(b/30467474) Once GPU infeed implementation settles, consider +// folding back the cpu and gpu infeed implementations into a generic +// one if possible. +GpuTransferManager::GpuTransferManager() + : GenericTransferManager(se::cuda::kCudaPlatformId) {} + +Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, + const Literal& literal) { + const Shape& shape = literal.shape(); + VLOG(2) << "Transferring literal shape to infeed: " + << ShapeUtil::HumanString(shape); + + // TODO(b/30467474) handle tuples. + if (ShapeUtil::IsTuple(shape)) { + return Unimplemented("Infeed with a tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + int64 size = GetByteSizeRequirement(shape); + if (size > std::numeric_limits::max()) { + return Unimplemented("Infeed shape is too large: %s needs %lld bytes", + ShapeUtil::HumanString(literal.shape()).c_str(), size); + } + + if (size == 0) { + return Unimplemented("Infeed shape %s needs 0 bytes", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); + se::Stream* stream = infeed_manager->GetStream(executor); + if (stream == nullptr) { + return InternalError("Failed to obtain a stream"); + } + + gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); + stream->ThenMemcpy(buffer->device_memory(), + LiteralUtil::InternalData(literal), size); + + VLOG(2) << "Queued infeed data on stream " << stream; + + if (!stream->BlockHostUntilDone()) { + buffer->Done(); + return InternalError("Failed to complete data transfer on stream %p", + stream); + } + + infeed_manager->EnqueueBuffer(buffer); + + VLOG(2) << "Infeed data transferred"; + return Status::OK(); +} + +} // namespace xla + +static std::unique_ptr CreateGpuTransferManager() { + return xla::MakeUnique(); +} + +static bool InitModule() { + xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId, + &CreateGpuTransferManager); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu_transfer_manager.h new file mode 100644 index 00000000000..6dfe7ba0295 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ + +#include + +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// An implementation of the XLA GenericTransferManager that +// handles GPU-specific infeed. +class GpuTransferManager : public GenericTransferManager { + public: + GpuTransferManager(); + ~GpuTransferManager() override {} + + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, + const Literal& literal) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc new file mode 100644 index 00000000000..86f62accd3b --- /dev/null +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -0,0 +1,607 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/heap_simulator.h" + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; + +namespace { + +// Returns the set of buffers that may be sources of all operands of the given +// instruction. The returned buffers are guaranteed to have no duplicates, and +// to be sorted in a deterministic order. 
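The helper that follows uses a dedupe-then-sort idiom: collect into a hash set to drop duplicates, then order by buffer id so iteration is deterministic across runs. A distilled sketch with a stand-in `Buffer` type:

```c++
#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

struct Buffer { int64_t id; };  // stand-in for LogicalBuffer

std::vector<const Buffer*> UniqueSorted(const std::vector<const Buffer*>& in) {
  // The hash set removes duplicates but has no stable order...
  std::unordered_set<const Buffer*> unique(in.begin(), in.end());
  // ...so sort by id to make the result deterministic.
  std::vector<const Buffer*> sorted(unique.begin(), unique.end());
  std::sort(sorted.begin(), sorted.end(),
            [](const Buffer* a, const Buffer* b) { return a->id < b->id; });
  return sorted;
}
```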
+std::vector UniqueOperandSourceBuffers( + const HloInstruction* instruction, + const TuplePointsToAnalysis& points_to_analysis) { + FlatSet buffers; + for (const HloInstruction* operand : instruction->operands()) { + FlatSet sources = + points_to_analysis.GetPointsToSet(operand).CreateFlattenedSet(); + buffers.insert(sources.begin(), sources.end()); + } + std::vector sorted(buffers.begin(), buffers.end()); + std::sort(sorted.begin(), sorted.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return sorted; +} + +} // namespace + +/*static*/ +StatusOr HeapSimulator::Run( + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, + &module_sequence); + const HloComputation* entry_computation = module.entry_computation(); + const std::vector& instruction_sequence = + FindOrDie(module_sequence, entry_computation); + TF_RETURN_IF_ERROR(heap.RunComputation( + *entry_computation, instruction_sequence, points_to_analysis)); + return heap.Finish(); +} + +/*static*/ +StatusOr HeapSimulator::Run( + std::unique_ptr algorithm, const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, + /*module_sequence=*/nullptr); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + points_to_analysis)); + return heap.Finish(); +} + +// Runs a heap simulation for the given 'computation', assuming the given +// 'instruction_sequence'. +Status HeapSimulator::RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis) { + // The goal here is to minimize memory usage, assuming the given sequential + // ordering of instructions. The strategy is to walk through the instruction + // sequence, calling Alloc and Free on the underlying heap algorithm. The + // heap algorithm takes care of packing and reducing fragmentation. + // + // 'live_buffers' tracks the liveness of each buffer that we assign, by + // associating it with a set of HloInstructions that need to be visited. When + // the set becomes empty, the buffer is no longer used, and can be freed. + FlatMap> live_buffers; + + const HloInstruction* root = computation.root_instruction(); + FlatSet output_source_buffers = + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + + for (const HloInstruction* instruction : instruction_sequence) { + const std::vector& buffers_defined_by_instruction = + points_to_analysis.GetBuffersDefinedByInstruction(instruction); + + // Initialize live_buffers for each buffer that we're going to assign. The + // set of instructions that need to be visited contains all users of all + // aliases. The alias itself is not necessary; if it has users, the users + // are necessarily scheduled after the alias. And if it has no users, it is + // either a dead value or an output, both of which are handled below. + // + // We ignore control dependencies here. 
The reasoning is that the control + // dependencies have already been accounted for in the ordering of the given + // 'instruction_sequence', and should not otherwise artificially extend the + // lifetime of buffers that aren't already connected by a data dependency. + std::vector dead_buffers_to_free; + for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + if (IgnoreBuffer(buffer)) { + continue; + } + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + const std::vector& users = + alias.instruction()->users(); + if (!users.empty()) { + live_buffers[buffer].insert(users.begin(), users.end()); + } + } + + // Add a nullptr sentry to ensure entry parameters and output source + // buffers are not freed until the very end. + const bool entry_parameter = + &computation == computation.parent()->entry_computation() && + buffer->instruction()->opcode() == HloOpcode::kParameter; + const bool output = output_source_buffers.count(buffer) > 0; + if (entry_parameter || output) { + live_buffers[buffer].insert(nullptr); + } + + // If the buffer has no users and isn't an entry parameter or output, it + // must be a dead value. + if (live_buffers.count(buffer) == 0) { + dead_buffers_to_free.push_back(buffer); + } + } + + // Update live_buffers to indicate we've visited this instruction; this is + // the inverse of the initialization logic. We erase this instruction from + // all source buffers of all operands of this instruction. Buffers that + // have no instructions left to visit are moved from live_buffers to + // operand_buffers_to_free. + std::vector operand_buffers_to_free; + for (const LogicalBuffer* operand_buffer : + UniqueOperandSourceBuffers(instruction, points_to_analysis)) { + if (IgnoreBuffer(operand_buffer)) { + continue; + } + live_buffers[operand_buffer].erase(instruction); + if (live_buffers[operand_buffer].empty()) { + live_buffers.erase(operand_buffer); + operand_buffers_to_free.push_back(operand_buffer); + } + } + + // Allocate buffers defined by this instruction. This is the latest point + // that we can allocate; right before the buffer is first used. This must + // happen before dead or operand buffers are freed; the instruction reads + // the operand buffers to produce its output. + // + // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer + // that we should assign. + for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + if (IgnoreBuffer(buffer)) { + continue; + } + + // Check whether the buffer can share with one of its operands; we can + // save memory by sharing the buffer, rather than allocating a new one. + // We can only share with the operand buffer if it is about to be freed; + // we must be the last user of the buffer. + bool shared = false; + for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && + buffer->instruction()->opcode() != HloOpcode::kCopy && + CanShareOperandBufferWithUser( + operand_buffer->instruction(), operand_buffer->index(), + buffer->instruction(), buffer->index(), points_to_analysis)) { + ShareBuffer(buffer, operand_buffer, instruction); + shared = true; + break; + } + } + + if (!shared) { + Alloc(buffer, instruction); + } + } + + // If the whole module is sequential, we can save memory by running the + // heap-simulation for sub-computations inline. E.g. the buffers for the + // condition and body of a kWhile instruction are only live for the duration + // of the instruction itself. 
+ // + // The order that the sub-computations are simulated does not affect + // correctness; since the whole module is sequential, we know that the + // sub-computations will never be run concurrently. + if (module_sequence_ != nullptr) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kWhile) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + const std::vector& called_sequence = + FindOrDie(*module_sequence_, called_computation); + TF_RETURN_IF_ERROR(RunComputation( + *called_computation, called_sequence, points_to_analysis)); + } + } + + // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are + // assigned "thread-local" allocations, meaning their buffers are not + // allocated up-front at the beginning of the computation. + } + + // Free buffers that are no longer live. This is the earliest point that we + // can de-allocate; right after the last use of the buffer. + for (const LogicalBuffer* buffer : dead_buffers_to_free) { + Free(buffer, instruction); + } + for (const LogicalBuffer* buffer : operand_buffers_to_free) { + Free(buffer, instruction); + } + } + + // Any remaining live buffers must be entry parameters or output source + // buffers, which had a nullptr sentry added. Free them now. + for (const auto& buffer_pending : live_buffers) { + const LogicalBuffer* buffer = buffer_pending.first; + const FlatSet& pending = buffer_pending.second; + CHECK_EQ(pending.size(), 1) << *buffer; + CHECK(*pending.begin() == nullptr) << *buffer; + Free(buffer, root); + } + + return Status::OK(); +} + +HeapSimulator::HeapSimulator( + std::unique_ptr algorithm, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet* buffers_to_assign, + const SequentialHloOrdering::HloModuleSequence* module_sequence) + : no_fragmentation_stats_(MakeUnique()), + algorithm_(std::move(algorithm)), + size_fn_(size_fn), + buffers_to_assign_(buffers_to_assign), + module_sequence_(module_sequence) { + debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); +} + +HeapSimulator::~HeapSimulator() {} + +bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { + // Buffers for constants are ignored, as with BufferAssigner. Also ignore + // buffers that we're not meant to assign. + // + // TODO(b/32248867): For consistency, constants should get allocations. + return buffer->instruction()->opcode() == HloOpcode::kConstant || + (buffers_to_assign_ != nullptr && + buffers_to_assign_->count(buffer) == 0); +} + +// Alloc always calls the underlying heap algorithm. +void HeapSimulator::Alloc(const LogicalBuffer* buffer, + const HloInstruction* instruction) { + CHECK(allocated_buffers_.count(buffer) == 0) + << "Alloc called on allocated buffer: " << *buffer; + CHECK(freed_buffers_.count(buffer) == 0) + << "Alloc called on freed buffer: " << *buffer; + + allocated_buffers_.insert(buffer); + const int64 size = size_fn_(*buffer); + algorithm_->Alloc(buffer, size); + no_fragmentation_stats_->Alloc(buffer, size); + + FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, + nullptr); +} + +// Free calls the underlying algorithm for non-shared buffers, and for shared +// buffers whose group liveness has expired. Shared group liveness is tracked +// by maintaining a refcount; the Free call on the last buffer in the group +// causes Free to be called on the underlying algorithm. 
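A sketch of the shared-group refcounting just described: shared buffers point at one group, and only the `Free` of the last member reaches the underlying algorithm, which then frees the canonical buffer. Types below are simplified stand-ins for the HeapSimulator internals:

```c++
#include <map>
#include <memory>

struct Buffer {};  // stand-in for LogicalBuffer

struct SharedGroup {
  const Buffer* canonical = nullptr;
  int refcount = 0;
};

class RefcountedFree {
 public:
  // Records that `b` shares storage with group `g`.
  void Share(const Buffer* b, std::shared_ptr<SharedGroup> g) {
    ++g->refcount;
    groups_[b] = std::move(g);
  }

  // Returns the buffer to actually free, or nullptr while the group is
  // still live.
  const Buffer* Free(const Buffer* b) {
    auto it = groups_.find(b);
    if (it == groups_.end()) return b;             // not shared: free directly
    if (--it->second->refcount > 0) return nullptr;
    return it->second->canonical;  // last member frees the canonical buffer
  }

 private:
  std::map<const Buffer*, std::shared_ptr<SharedGroup>> groups_;
};
```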
+void HeapSimulator::Free(const LogicalBuffer* buffer,
+                         const HloInstruction* instruction) {
+  auto shared_it = shared_buffers_.find(buffer);
+  if (shared_it != shared_buffers_.end()) {
+    std::shared_ptr<SharedGroup> group = shared_it->second;
+    --group->refcount;
+    if (group->refcount > 0) {
+      return;
+    }
+    CHECK_EQ(group->refcount, 0)
+        << "Free caused negative refcount on shared buffer: " << *buffer;
+    buffer = group->canonical;
+  }
+
+  CHECK(allocated_buffers_.count(buffer) > 0)
+      << "Free called on non-allocated buffer: " << *buffer;
+  CHECK(freed_buffers_.count(buffer) == 0)
+      << "Free called on freed buffer: " << *buffer;
+
+  freed_buffers_.insert(buffer);
+  const int64 size = size_fn_(*buffer);
+  algorithm_->Free(buffer, size);
+  no_fragmentation_stats_->Free(buffer, size);
+
+  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction,
+                 nullptr);
+}
+
+// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
+// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
+// to Alloc. The 'shared' buffer must be a previously allocated or shared
+// buffer. Both 'buffer' and 'shared' will be associated with the same
+// SharedGroup.
+void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer,
+                                const LogicalBuffer* shared,
+                                const HloInstruction* instruction) {
+  CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
+      << "ShareBuffer oversized buffer: " << *buffer << " shared: " << *shared;
+  CHECK(allocated_buffers_.count(buffer) == 0)
+      << "ShareBuffer called on allocated buffer: " << *buffer;
+  CHECK(freed_buffers_.count(buffer) == 0)
+      << "ShareBuffer called on freed buffer: " << *buffer;
+  CHECK(freed_buffers_.count(shared) == 0)
+      << "ShareBuffer called on freed shared buffer: " << *shared;
+
+  const LogicalBuffer* canonical = nullptr;
+  auto shared_it = shared_buffers_.find(shared);
+  if (shared_it != shared_buffers_.end()) {
+    // The 'shared' buffer already has a group; it might be the canonical, but
+    // also might not be. Just add 'buffer' to the existing group.
+    std::shared_ptr<SharedGroup> group = shared_it->second;
+    canonical = group->canonical;
+    ++group->refcount;
+    shared_buffers_.emplace(buffer, group);
+  } else {
+    // The 'shared' buffer doesn't have a group; it must be the canonical. Add
+    // both 'buffer' and 'shared' to a new group.
+    CHECK(allocated_buffers_.count(shared) > 0)
+        << "ShareBuffer called on non-allocated shared buffer: " << *shared;
+    auto group = std::make_shared<SharedGroup>();
+    canonical = shared;
+    group->canonical = canonical;
+    group->refcount = 2;
+    shared_buffers_.emplace(buffer, group);
+    shared_buffers_.emplace(shared, group);
+  }
+
+  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
+                 canonical);
+}
+
+HeapSimulator::Result HeapSimulator::Finish() {
+  Result result = algorithm_->Finish();
+
+  // Post-process the result to add chunks for shared buffers. An empty chunk
+  // map means that either no buffers were allocated, or the heap was only
+  // collecting statistics, e.g. NoFragmentationStatsHeap.
+  if (!result.chunk_map.empty()) {
+    for (const auto& share_pair : shared_buffers_) {
+      const LogicalBuffer* buffer = share_pair.first;
+      std::shared_ptr<SharedGroup> group = share_pair.second;
+      if (buffer != group->canonical) {
+        // The canonical must already exist in the chunk_map, since we called
+        // Alloc(canonical) on the underlying algorithm. Add non-canonical
+        // chunks with the same offset as the canonical.
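+        // (For instance, if the canonical buffer was assigned the chunk
+        // {offset=32, size=12} and this shared buffer is 8 bytes, this
+        // buffer is recorded with the chunk {offset=32, size=8}.)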
+        Chunk chunk = FindOrDie(result.chunk_map, group->canonical);
+        chunk.size = size_fn_(*buffer);
+        result.chunk_map.emplace(buffer, chunk);
+      }
+    }
+    // If we were told to assign specific buffers, make sure we've assigned
+    // exactly that many buffers.
+    if (buffers_to_assign_ != nullptr) {
+      CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size());
+    }
+  }
+
+  // Fragmentation is the difference between the actual and ideal sizes.
+  const Result no_frag_result = no_fragmentation_stats_->Finish();
+  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;
+
+  // Copy the debug trace we collected to the final result.
+  result.debug_trace.Swap(&debug_trace_);
+
+  return result;
+}
+
+void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
+                                   const LogicalBuffer* buffer,
+                                   const HloInstruction* instruction,
+                                   const LogicalBuffer* share_with_canonical) {
+  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
+  event->set_kind(kind);
+  event->set_buffer_id(buffer->id());
+  event->set_computation_name(instruction->parent()->name());
+  event->set_instruction_name(instruction->name());
+  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
+    CHECK(share_with_canonical != nullptr);
+    event->set_share_with_canonical_id(share_with_canonical->id());
+  } else {
+    CHECK(share_with_canonical == nullptr);
+  }
+}
+
+void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+  current_heap_size_ += size;
+  if (current_heap_size_ > max_heap_size_) {
+    max_heap_size_ = current_heap_size_;
+  }
+}
+
+void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) {
+  current_heap_size_ -= size;
+}
+
+HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
+  // The result.chunk_map is empty, since we only collect stats, and don't
+  // actually compute chunk assignments.
+  Result result;
+  result.heap_size = max_heap_size_;
+  return result;
+}
+
+void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+  SetMode(kAlloc);
+  run_.emplace_back(Op{buffer, size});
+}
+
+void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) {
+  CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer;
+  SetMode(kFree);
+  run_.emplace_back(Op{buffer, size});
+}
+
+HeapSimulator::Result DecreasingSizeRunsHeap::Finish() {
+  CallAndDrainRun();
+  return algorithm_->Finish();
+}
+
+void DecreasingSizeRunsHeap::SetMode(Mode mode) {
+  if (mode_ != mode) {
+    CallAndDrainRun();
+    mode_ = mode;
+  }
+}
+
+void DecreasingSizeRunsHeap::CallAndDrainRun() {
+  if (mode_ == kInit) {
+    CHECK(run_.empty());
+    return;
+  }
+
+  // Call ops in the run sorted by decreasing size, breaking ties by buffer id.
+  std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) {
+    if (a.size != b.size) {
+      return a.size > b.size;
+    }
+    return a.buffer->id() < b.buffer->id();
+  });
+  for (const Op& op : run_) {
+    if (mode_ == kAlloc) {
+      algorithm_->Alloc(op.buffer, op.size);
+    } else {
+      algorithm_->Free(op.buffer, op.size);
+    }
+  }
+  run_.clear();
+}
+
+void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+  // Degenerate case: 0-sized buffers are always allocated at offset 0.
+  if (size == 0) {
+    result_.chunk_map.emplace(buffer, Chunk{0, 0});
+  }
+
+  // First try to allocate from the best-fitting free chunk.
+  auto best_fit_it = free_.lower_bound(Chunk{0, size});
+  while (best_fit_it != free_.end()) {
+    // Account for alignment.
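+    // (Illustration: with alignment_ == 64, a free chunk at offset 10 with
+    // size 80 can hold a 20-byte buffer at offset 64, since
+    // RoundUpToNearest(10, 64) == 64 and 64 + 20 <= 10 + 80.)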
+    const Chunk best = *best_fit_it;
+    const int64 new_offset = RoundUpToNearest(best.offset, alignment_);
+    const int64 new_end = new_offset + size;
+    if (new_end > best.chunk_end()) {
+      // We don't fit after accounting for alignment.
+      ++best_fit_it;
+      continue;
+    }
+    // The buffer is allocated a chunk out of the best-fitting free chunk.
+    free_.erase(best_fit_it);
+    result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
+    // Add remaining portions of the best-fitting free chunk back into free_.
+    AddFreeChunk(best.offset, new_offset - best.offset);
+    AddFreeChunk(new_end, best.chunk_end() - new_end);
+    return;
+  }
+
+  // The buffer doesn't completely fit into any existing free chunk. If the
+  // last free chunk is adjacent to the end of the heap, allocate the buffer
+  // re-using that space, increasing the heap size.
+  //
+  // Allocating the buffer now causes the heap to grow by less than the buffer
+  // size, whereas if we allocated lazily in Free, the heap would grow by
+  // exactly the buffer size. However this is still a greedy heuristic; we
+  // might have ended up with a tighter packing by being lazy here.
+  //
+  // In theory we could also check if we could re-use space from the first free
+  // chunk and grow the heap at the front, and choose whether to grow from the
+  // front or back based on the amount of re-use. But that's more complicated,
+  // and these are all heuristics anyway, so it isn't implemented.
+  for (auto it = free_.begin(); it != free_.end(); ++it) {
+    if (it->chunk_end() == result_.heap_size) {
+      // Account for alignment in the last free chunk.
+      const Chunk last = *it;
+      const int64 new_offset = RoundUpToNearest(last.offset, alignment_);
+      if (new_offset >= last.chunk_end()) {
+        // There's no point in using the last free chunk if alignment causes
+        // us to skip over it anyway.
+        break;
+      }
+      // The buffer is allocated a chunk that includes the last free chunk.
+      free_.erase(it);
+      result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
+      // Add remaining portion of the last free chunk back into free_.
+      AddFreeChunk(last.offset, new_offset - last.offset);
+      // Grow the heap.
+      const int64 new_end = new_offset + size;
+      CHECK_GT(new_end, result_.heap_size);
+      CHECK_LT(new_end, result_.heap_size + size);
+      result_.heap_size = new_end;
+      return;
+    }
+  }
+
+  // Otherwise lazily allocate the buffer in Free.
+  result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size});
+}
+
+void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) {
+  auto alloc_it = result_.chunk_map.find(buffer);
+  CHECK(alloc_it != result_.chunk_map.end())
+      << "Free called on non-allocated buffer: " << *buffer;
+  Chunk* alloc = &alloc_it->second;
+  CHECK_EQ(alloc->size, size) << "Free with mismatched sizes: " << *buffer;
+  if (alloc->offset != kLazyAllocOffset) {
+    // The buffer was already allocated in Alloc, do a normal free.
+    AddFreeChunk(alloc->offset, alloc->size);
+  } else {
+    // This buffer is lazily allocated, so we *cannot* allocate out of
+    // existing free chunks, since that might cause interference between
+    // buffers. The buffer is allocated by growing the heap, accounting for
+    // alignment.
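+    // (The buffer has been live ever since its Alloc call, so an existing
+    // free chunk may have been occupied by some other buffer during part of
+    // that lifetime; growing the heap is the only placement known to be
+    // safe.)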
+    alloc->offset = RoundUpToNearest(result_.heap_size, alignment_);
+    const int64 new_end = alloc->chunk_end();
+    AddFreeChunk(result_.heap_size, new_end - result_.heap_size);
+    CHECK_GT(new_end, result_.heap_size);
+    CHECK_GE(new_end, result_.heap_size + alloc->size);
+    result_.heap_size = new_end;
+  }
+}
+
+void LazyBestFitHeap::AddFreeChunk(int64 offset, int64 size) {
+  if (size <= 0) {
+    return;
+  }
+
+  // Coalesce the chunk with adjacent free chunks on either side. We must
+  // remove the free chunks from free_, since it's ordered by size.
+  Chunk chunk{offset, size};
+  for (auto it = free_.begin(); it != free_.end();) {
+    if (it->chunk_end() == chunk.offset || it->offset == chunk.chunk_end()) {
+      chunk.offset = std::min(chunk.offset, it->offset);
+      chunk.size += it->size;
+      it = free_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
+  // This is the only place we add free chunks to free_. It maintains the
+  // invariant that all free chunks are disjoint and non-adjacent.
+  free_.emplace(chunk);
+}
+
+HeapSimulator::Result LazyBestFitHeap::Finish() {
+  if (!free_.empty()) {
+    // When Finish is called, all calls to Alloc must have had corresponding
+    // calls to Free, which will result in a single free chunk [0, heap_size).
+    CHECK_EQ(free_.size(), 1);
+    CHECK_EQ(free_.begin()->offset, 0);
+    CHECK_EQ(free_.begin()->size, result_.heap_size);
+  }
+  return result_;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
new file mode 100644
index 00000000000..a03ad2f37cf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -0,0 +1,284 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
+
+#include <memory>
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// Forward declare classes defined below.
+class HeapAlgorithm;
+
+// HeapSimulator assigns buffer offsets by running a simulation of a regular
+// memory heap with Alloc and Free calls. It only works for completely
+// sequential instruction sequences. Unlike regular heaps, we have the
+// advantage that the sequence of Alloc and Free calls is known up-front; we
+// don't need to return the assignment of buffer offsets until the very end.
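+//
+// Typical usage, as a sketch (module_sequence, points_to_analysis and
+// size_fn are assumed to have been computed already):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       HeapSimulator::Result result,
+//       HeapSimulator::Run(
+//           MakeUnique<DecreasingSizeRunsHeap>(
+//               MakeUnique<LazyBestFitHeap>(/*alignment=*/64)),
+//           module, module_sequence, points_to_analysis, size_fn));
+//   // result.chunk_map maps each LogicalBuffer to its {offset, size} chunk.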
+class HeapSimulator {
+ public:
+  // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be
+  // associated with a chunk in the assignment result.
+  struct Chunk {
+    int64 offset;
+    int64 size;
+
+    int64 chunk_end() const { return offset + size; }
+  };
+
+  // Result represents the result of the heap simulation.
+  struct Result {
+    // The assignment of buffers to chunks.
+    tensorflow::gtl::FlatMap<const LogicalBuffer*, Chunk> chunk_map;
+
+    // The total size in bytes of the heap, containing all assigned chunks.
+    int64 heap_size = 0;
+
+    // The total size in bytes of heap fragmentation.
+    int64 fragmentation_size = 0;
+
+    // A trace of heap simulation events.
+    HeapSimulatorTrace debug_trace;
+  };
+
+  // Run the heap simulation with the given algorithm, assuming the given
+  // module_sequence, which must contain a topologically-consistent total
+  // ordering of all instructions within each computation. The result is
+  // invalid if instructions are not run in exactly this sequence.
+  //
+  // Running heap simulation on the whole module tends to save memory,
+  // compared to running on a per-computation basis, since we can re-use
+  // buffer space for called sub-computations.
+  //
+  // If 'buffers_to_assign' is provided, only those buffers are assigned
+  // offsets, otherwise all buffers defined by the instructions are assigned.
+  static StatusOr<Result> Run(
+      std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
+      const SequentialHloOrdering::HloModuleSequence& module_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const LogicalBuffer::SizeFunction& size_fn,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>*
+          buffers_to_assign = nullptr);
+
+  // Same as above, but runs on a single computation. The
+  // 'instruction_sequence' must contain a topologically-consistent total
+  // ordering of all instructions in the computation. The result is invalid
+  // if instructions are not run in exactly this sequence.
+  static StatusOr<Result> Run(
+      std::unique_ptr<HeapAlgorithm> algorithm,
+      const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
+      const TuplePointsToAnalysis& points_to_analysis,
+      const LogicalBuffer::SizeFunction& size_fn,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>*
+          buffers_to_assign = nullptr);
+
+ private:
+  // If 'module_sequence' is non-null, it is used to find kCall and kWhile
+  // sub-computations, and the heap simulation for those sub-computations will
+  // be run recursively. I.e. the simulation is run over the whole module.
+  HeapSimulator(
+      std::unique_ptr<HeapAlgorithm> algorithm,
+      const LogicalBuffer::SizeFunction& size_fn,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign,
+      const SequentialHloOrdering::HloModuleSequence* module_sequence);
+  ~HeapSimulator();
+
+  Status RunComputation(
+      const HloComputation& computation,
+      const std::vector<const HloInstruction*>& instruction_sequence,
+      const TuplePointsToAnalysis& points_to_analysis);
+
+  bool IgnoreBuffer(const LogicalBuffer* buffer) const;
+  void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction);
+  void Free(const LogicalBuffer* buffer, const HloInstruction* instruction);
+  void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared,
+                   const HloInstruction* instruction);
+  Result Finish();
+
+  void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
+                      const LogicalBuffer* buffer,
+                      const HloInstruction* instruction,
+                      const LogicalBuffer* share_with_canonical);
+
+  const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
+  const std::unique_ptr<HeapAlgorithm> algorithm_;
+  const LogicalBuffer::SizeFunction size_fn_;
+  const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign_;
+  const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+
+  // In addition to Alloc and Free, the heap simulator exposes a concept of
+  // buffer sharing. When ShareBuffer is called, instead of allocating new
+  // space for the buffer, it associates the buffer with a previously
+  // allocated (or shared) buffer. Each group of mutually-shared buffers
+  // points to a single SharedGroup instance, which is a shared control block.
+  //
+  // This forced buffer sharing is hidden from the underlying heap algorithm,
+  // which only sees a regular Alloc call on the canonical buffer. The
+  // corresponding Free call is delayed until the liveness of all shared
+  // buffers in the group has expired, which is tracked via the refcount. The
+  // results are post-processed in Finish to add chunks for shared buffers.
+  //
+  // The shared_buffers_ map associates each shared buffer (including the
+  // canonical) to its SharedGroup control block.
+  struct SharedGroup {
+    const LogicalBuffer* canonical = nullptr;
+    int64 refcount = 0;
+  };
+  tensorflow::gtl::FlatMap<const LogicalBuffer*, std::shared_ptr<SharedGroup>>
+      shared_buffers_;
+
+  // Hold some sets for error-checking the sequence of Alloc and Free calls.
+  tensorflow::gtl::FlatSet<const LogicalBuffer*> allocated_buffers_;
+  tensorflow::gtl::FlatSet<const LogicalBuffer*> freed_buffers_;
+
+  // Debugging information filled in while the heap simulator runs.
+  HeapSimulatorTrace debug_trace_;
+};
+
+// Abstract base class describing a heap simulation algorithm that assigns
+// offsets to buffers. A sequence of Alloc / Free calls will be made, with the
+// same semantics as a regular memory heap. Finish will be called at the end
+// to collect the simulation results.
+class HeapAlgorithm {
+ public:
+  using Chunk = HeapSimulator::Chunk;
+  using Result = HeapSimulator::Result;
+
+  virtual ~HeapAlgorithm() = default;
+
+  // Alloc allocates a buffer of 'size' bytes.
+  virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0;
+
+  // Free de-allocates a previously allocated buffer.
+  virtual void Free(const LogicalBuffer* buffer, int64 size) = 0;
+
+  // Finish collects the buffer offset assignment results. Finish may only be
+  // called once, after all Alloc and Free calls.
+  virtual Result Finish() = 0;
+};
+
+// NoFragmentationStatsHeap computes the heap size assuming no fragmentation;
+// this is the absolute minimum size for a given instruction sequence.
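+// For example, given the call sequence Alloc(A, 10), Alloc(B, 20),
+// Free(B, 20), Alloc(C, 30), the peak simultaneous usage, and hence the
+// reported heap_size, is A + C = 40 bytes.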
The +// result.chunk_map returned in Finish is always empty, since we only collect +// stats, and don't actually compute chunk assignments. +class NoFragmentationStatsHeap : public HeapAlgorithm { + public: + NoFragmentationStatsHeap() = default; + ~NoFragmentationStatsHeap() override = default; + + void Alloc(const LogicalBuffer* buffer, int64 size) override; + void Free(const LogicalBuffer* buffer, int64 size) override; + Result Finish() override; + + private: + int64 current_heap_size_ = 0; + int64 max_heap_size_ = 0; +}; + +// DecreasingSizeRunsHeap collects runs of Alloc and Free calls, sorts them by +// decreasing size, and delegates the actual calls to another heap algorithm. +// This greedy heuristic tends to reduce fragmentation for all algorithms. +class DecreasingSizeRunsHeap : public HeapAlgorithm { + public: + DecreasingSizeRunsHeap(std::unique_ptr algorithm) + : algorithm_(std::move(algorithm)) {} + ~DecreasingSizeRunsHeap() override {} + + void Alloc(const LogicalBuffer* buffer, int64 size) override; + void Free(const LogicalBuffer* buffer, int64 size) override; + Result Finish() override; + + private: + // A single Alloc or Free operation that we've buffered in run_. + struct Op { + const LogicalBuffer* buffer; + int64 size; + }; + + // Current collection mode; kInit means no ops have been collected yet. + enum Mode { kInit, kAlloc, kFree }; + + void SetMode(Mode mode); + void CallAndDrainRun(); + + const std::unique_ptr algorithm_; + std::vector run_; + Mode mode_ = kInit; +}; + +// LazyBestFitHeap is a variant of the traditional best-fit heap. This is a +// greedy heuristic, based on the idea that delaying offset assignment helps +// reduce fragmentation. Here's an example of a "bad" offset assignment, where +// a tiny buffer A prevents adjacent free chunks from being coalesced: +// BAD: | free |A| free | +// If we could have delayed the assignment of A, we might have ended up with: +// GOOD: | free |A| +// +// In general it's actually hard to say whether GOOD is better than BAD; the +// heuristic we use is we try to leave large contiguous chunks free, and we try +// to avoid growing the overall heap size unless necessary. +// +// Just like regular best-fit, in Alloc we look for the smallest free chunk that +// fits the requested size. Unlike regular best-fit, we postpone offset +// assignment for buffers that cannot re-use existing free chunks (and force us +// to grow the heap); these buffers are "lazily" assigned offsets in Free. +class LazyBestFitHeap : public HeapAlgorithm { + public: + LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} + ~LazyBestFitHeap() override {} + + void Alloc(const LogicalBuffer* buffer, int64 size) override; + void Free(const LogicalBuffer* buffer, int64 size) override; + Result Finish() override; + + private: + // Sentry value used to indicate a chunk that wasn't assigned an offset in + // Alloc, and will instead be assigned an offset in Free. + enum { kLazyAllocOffset = -1 }; + + struct OrderChunkByIncreasingSize { + bool operator()(const Chunk& a, const Chunk& b) { + if (a.size != b.size) return a.size < b.size; + return a.offset < b.offset; + } + }; + + void AddFreeChunk(int64 offset, int64 size); + + const int64 alignment_; + Result result_; + + // Maintain the set of free chunks, ordered by increasing size. 
+  std::set<Chunk, OrderChunkByIncreasingSize> free_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
new file mode 100644
index 00000000000..60a0768a86b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -0,0 +1,849 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+namespace {
+
+const char kAlloc[] = "Alloc";
+const char kFree[] = "Free";
+const char kFinish[] = "Finish";
+
+// CallSequence records a sequence of Alloc/Free/Finish calls.
+using CallSequence = std::vector<std::pair<string, const LogicalBuffer*>>;
+
+// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
+class HeapCallRecorder : public HeapAlgorithm {
+ public:
+  explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
+  ~HeapCallRecorder() override {}
+
+  void Alloc(const LogicalBuffer* buffer, int64 size) override {
+    calls_->emplace_back(kAlloc, buffer);
+    // Instead of assigning a real offset, we set the cardinality of the Alloc
+    // call. This isn't a valid assignment, but allows us to easily test for
+    // buffer sharing.
+    const int64 offset = result_.chunk_map.size();
+    result_.chunk_map.emplace(buffer, Chunk{offset, size});
+  }
+  void Free(const LogicalBuffer* buffer, int64 size) override {
+    calls_->emplace_back(kFree, buffer);
+  }
+  Result Finish() override {
+    calls_->emplace_back(kFinish, nullptr);
+    return result_;
+  }
+
+ private:
+  CallSequence* calls_;
+  Result result_;
+};
+
+// HeapSimulatorTracker runs the heap simulator, recording the sequence of
+// calls made to the underlying heap algorithm. Tests compare the actual call
+// sequence against an expected sequence.
+class HeapSimulatorTracker {
+ public:
+  // Constructor for testing a single entry computation.
+  HeapSimulatorTracker(
+      const string& name, std::unique_ptr<HloComputation> computation,
+      const std::vector<const HloInstruction*>& instruction_sequence) {
+    module_ = MakeUnique<HloModule>(name);
+    module_->AddEntryComputation(std::move(computation));
+    points_to_analysis_ =
+        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+    // Since we're only tracking the sequence of Alloc/Free calls, the actual
+    // size of the buffers doesn't matter, so we always return 0. We rely on
+    // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
+    // by buffer id, for determinism in the tests.
+    auto zero_size = [](const LogicalBuffer& buffer) { return 0; };
+    auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
+        MakeUnique<HeapCallRecorder>(&actual_calls_));
+    result_ = HeapSimulator::Run(
+                  std::move(algorithm), *module_->entry_computation(),
+                  instruction_sequence, *points_to_analysis_, zero_size)
+                  .ConsumeValueOrDie();
+  }
+
+  explicit HeapSimulatorTracker(const string& name) {
+    module_ = MakeUnique<HloModule>(name);
+  }
+
+  // Similar to the single entry computation constructor above, but runs the
+  // simulation over the entire module.
+  void RunWholeModule(
+      const std::vector<const HloInstruction*>& full_module_sequence) {
+    points_to_analysis_ =
+        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+
+    // Construct the module sequence grouped by computation.
+    SequentialHloOrdering::HloModuleSequence module_sequence;
+    tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
+    for (int i = 0; i < full_module_sequence.size(); ++i) {
+      const HloInstruction* instruction = full_module_sequence[i];
+      module_sequence[instruction->parent()].push_back(instruction);
+      reverse_position[instruction] = full_module_sequence.size() - i;
+    }
+
+    // Hack the size_fn so that it returns a decreasing value as we step
+    // through the sequence. This lets us ensure the Alloc calls are in the
+    // sequence order. The Free calls are sorted by LogicalBuffer.id, which is
+    // at least deterministic.
+    auto size_fn = [&reverse_position](const LogicalBuffer& buffer) {
+      return reverse_position[buffer.instruction()];
+    };
+    auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
+        MakeUnique<HeapCallRecorder>(&actual_calls_));
+    result_ = HeapSimulator::Run(std::move(algorithm), *module_,
+                                 module_sequence, *points_to_analysis_,
+                                 size_fn)
+                  .ConsumeValueOrDie();
+  }
+
+  HloModule* module() { return module_.get(); }
+
+  // Returns the buffer defined at the given instruction and index.
+  const LogicalBuffer* BufferAt(const HloInstruction* instruction,
+                                const ShapeIndex& index) const {
+    return points_to_analysis_->GetBufferDefinedAt(instruction, index)
+        .ConsumeValueOrDie();
+  }
+
+  // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
+  void ExpectCallSequence(const CallSequence& expected) const {
+    EXPECT_EQ(expected, actual_calls_);
+  }
+
+  // Ensures the buffers defined by the respective (instruction,index) pairs
+  // are shared, relying on the unique offsets assigned in
+  // HeapCallRecorder::Alloc.
+  void ExpectSharedBuffers(const HloInstruction* instruction_a,
+                           const ShapeIndex& index_a,
+                           const HloInstruction* instruction_b,
+                           const ShapeIndex& index_b) {
+    const LogicalBuffer* a = BufferAt(instruction_a, index_a);
+    const LogicalBuffer* b = BufferAt(instruction_b, index_b);
+    EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset)
+        << *a << ", " << *b;
+  }
+
+ private:
+  std::unique_ptr<HloModule> module_;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+  CallSequence actual_calls_;
+  HeapSimulator::Result result_;
+};
+
+class HeapSimulatorTest : public HloTestBase {
+ protected:
+  HeapSimulatorTest() {}
+  ~HeapSimulatorTest() override {}
+
+  // Shapes for use in the examples.
+  Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
+  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
+};
+
+TEST_F(HeapSimulatorTest, ScalarConstant) {
+  auto builder = HloComputation::Builder(TestName());
+  auto const0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+
+  // Constants aren't assigned. See b/32248867
+  HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
+  tracker.ExpectCallSequence({{kFinish, nullptr}});
+}
+
+TEST_F(HeapSimulatorTest, OneParam) {
+  auto builder = HloComputation::Builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "param0"));
+
+  // A single parameter which is also the output.
+  HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(param0, {})},
+      {kFree, tracker.BufferAt(param0, {})},
+      {kFinish, nullptr},
+  });
+}
+
+TEST_F(HeapSimulatorTest, Multiply) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+
+  // We must keep all parameters and outputs.
+  HeapSimulatorTracker tracker(TestName(), builder.Build(),
+                               {paramA, paramX, mul});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(mul, {})},
+      {kFinish, nullptr},
+  });
+}
+
+TEST_F(HeapSimulatorTest, MultiplyAdd) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
+
+  // The buffer for add is the output, and it's shared with the buffer for
+  // mul.
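+  // (add is elementwise and the last user of mul, so
+  // CanShareOperandBufferWithUser lets it re-use mul's space.)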
+  HeapSimulatorTracker tracker(TestName(), builder.Build(),
+                               {paramA, paramX, mul, paramY, add});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      {kAlloc, tracker.BufferAt(paramY, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(mul, {})},
+      {kFree, tracker.BufferAt(paramY, {})},
+      {kFinish, nullptr},
+  });
+  tracker.ExpectSharedBuffers(add, {}, mul, {});
+}
+
+TEST_F(HeapSimulatorTest, MultiplyDot) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+
+  // The buffer for dot is the output, and it cannot be shared with the buffer
+  // for mul, since dot isn't elementwise.
+  HeapSimulatorTracker tracker(TestName(), builder.Build(),
+                               {paramA, paramX, mul, paramY, dot});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(dot, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(mul, {})},
+      {kFree, tracker.BufferAt(paramY, {})},
+      {kFree, tracker.BufferAt(dot, {})},
+      {kFinish, nullptr},
+  });
+}
+
+TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
+
+  // The buffer for add is the output, and it's shared with the buffer for
+  // dot.
+  HeapSimulatorTracker tracker(TestName(), builder.Build(),
+                               {paramA, paramX, mul, paramY, dot, add});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(dot, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(mul, {})},
+      {kFree, tracker.BufferAt(paramY, {})},
+      {kFree, tracker.BufferAt(dot, {})},
+      {kFinish, nullptr},
+  });
+  tracker.ExpectSharedBuffers(add, {}, dot, {});
+}
+
+TEST_F(HeapSimulatorTest, MultiplyDotDot) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto dot0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+  auto dot1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY));
+
+  // The buffer for dot1 is the output. No buffers can be shared. The buffer
+  // for mul is freed before the end, since it's no longer used after dot0
+  // finishes.
+  HeapSimulatorTracker tracker(TestName(), builder.Build(),
+                               {paramA, paramX, mul, paramY, dot0, dot1});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(dot0, {})},
+      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
+      {kAlloc, tracker.BufferAt(dot1, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(paramY, {})},
+      {kFree, tracker.BufferAt(dot0, {})},
+      {kFree, tracker.BufferAt(dot1, {})},
+      {kFinish, nullptr},
+  });
+}
+
+TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
+  auto builder = HloComputation::Builder(TestName());
+  auto paramA = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+  auto paramX = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
+  auto paramY = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
+  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+  auto dot0 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+  auto dot1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY));
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
+
+  // The buffers for dot0, dot1 and tuple are the output. No buffers can be
+  // shared. The buffer for mul is freed before the end, since it's no longer
+  // used after dot0 finishes.
+  HeapSimulatorTracker tracker(
+      TestName(), builder.Build(),
+      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
+  tracker.ExpectCallSequence({
+      {kAlloc, tracker.BufferAt(paramA, {})},
+      {kAlloc, tracker.BufferAt(paramX, {})},
+      {kAlloc, tracker.BufferAt(mul, {})},
+      {kAlloc, tracker.BufferAt(paramY, {})},
+      {kAlloc, tracker.BufferAt(dot0, {})},
+      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
+      {kAlloc, tracker.BufferAt(dot1, {})},
+      {kAlloc, tracker.BufferAt(tuple, {})},
+      // All params and outputs are freed at the end.
+      {kFree, tracker.BufferAt(paramA, {})},
+      {kFree, tracker.BufferAt(paramX, {})},
+      {kFree, tracker.BufferAt(paramY, {})},
+      {kFree, tracker.BufferAt(dot0, {})},
+      {kFree, tracker.BufferAt(dot1, {})},
+      {kFree, tracker.BufferAt(tuple, {})},
+      {kFinish, nullptr},
+  });
+}
+
+TEST_F(HeapSimulatorTest, WholeModule) {
+  HeapSimulatorTracker tracker(TestName());
+
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+  HloInstruction* cond_iter = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+  HloInstruction* cond_data = cond_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+  HloInstruction* cond_lt = cond_builder.AddInstruction(
+      HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+                                   HloOpcode::kLt, cond_iter, cond_data));
+  HloComputation* cond_computation =
+      tracker.module()->AddEmbeddedComputation(cond_builder.Build());
+
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+  HloComputation* body_computation =
+      tracker.module()->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  HloInstruction* while_op =
+      builder.AddInstruction(HloInstruction::CreateWhile(
+          tuple_shape, cond_computation, body_computation, param));
+  tracker.module()->AddEntryComputation(builder.Build());
+
+  tracker.RunWholeModule({param, while_op, body_param, cond_param, cond_iter,
+                          cond_data, cond_lt});
+  tracker.ExpectCallSequence({
+      // The entry computation param and while_op are allocated first.
+      {kAlloc, tracker.BufferAt(param, {})},
+      {kAlloc, tracker.BufferAt(param, {0})},
+      {kAlloc, tracker.BufferAt(param, {1})},
+      {kAlloc, tracker.BufferAt(while_op, {})},
+      {kAlloc, tracker.BufferAt(while_op, {0})},
+      {kAlloc, tracker.BufferAt(while_op, {1})},
+
+      // Now the while body param is allocated and freed.
+      {kAlloc, tracker.BufferAt(body_param, {})},
+      {kAlloc, tracker.BufferAt(body_param, {0})},
+      {kAlloc, tracker.BufferAt(body_param, {1})},
+      {kFree, tracker.BufferAt(body_param, {})},
+      {kFree, tracker.BufferAt(body_param, {0})},
+      {kFree, tracker.BufferAt(body_param, {1})},
+
+      // Now the while cond param is allocated. The GTE instructions just
+      // alias the param elements, so the param tuple can immediately be
+      // freed.
+      {kAlloc, tracker.BufferAt(cond_param, {})},
+      {kAlloc, tracker.BufferAt(cond_param, {0})},
+      {kAlloc, tracker.BufferAt(cond_param, {1})},
+      {kFree, tracker.BufferAt(cond_param, {})},
+
+      // Now the final cond less-than buffer is allocated.
+      {kAlloc, tracker.BufferAt(cond_lt, {})},
+
+      // The order of the remaining Free calls is based on the
+      // LogicalBuffer.id, which is deterministic, but not obvious.
+      {kFree, tracker.BufferAt(param, {})},
+      {kFree, tracker.BufferAt(param, {0})},
+      {kFree, tracker.BufferAt(param, {1})},
+
+      {kFree, tracker.BufferAt(while_op, {})},
+      {kFree, tracker.BufferAt(while_op, {0})},
+      {kFree, tracker.BufferAt(while_op, {1})},
+
+      {kFree, tracker.BufferAt(cond_param, {0})},
+      {kFree, tracker.BufferAt(cond_param, {1})},
+      {kFree, tracker.BufferAt(cond_lt, {})},
+
+      {kFinish, nullptr},
+  });
+}
+
+// Base class for heap algorithm tests.
+class HeapAlgorithmTestBase : public ::testing::Test {
+ protected:
+  HeapAlgorithmTestBase() {
+    buffer_a_ = DummyLogicalBuffer();
+    buffer_b_ = DummyLogicalBuffer();
+    buffer_c_ = DummyLogicalBuffer();
+    buffer_d_ = DummyLogicalBuffer();
+    buffer_e_ = DummyLogicalBuffer();
+    buffer_f_ = DummyLogicalBuffer();
+    buffer_g_ = DummyLogicalBuffer();
+    buffer_h_ = DummyLogicalBuffer();
+    buffer_i_ = DummyLogicalBuffer();
+  }
+  ~HeapAlgorithmTestBase() override {}
+
+  const LogicalBuffer* buffer_a_;
+  const LogicalBuffer* buffer_b_;
+  const LogicalBuffer* buffer_c_;
+  const LogicalBuffer* buffer_d_;
+  const LogicalBuffer* buffer_e_;
+  const LogicalBuffer* buffer_f_;
+  const LogicalBuffer* buffer_g_;
+  const LogicalBuffer* buffer_h_;
+  const LogicalBuffer* buffer_i_;
+
+ private:
+  // Create a dummy LogicalBuffer to pass to the heap algorithm. Since the
+  // algorithms only use the buffer as a handle, we don't need to fill in much
+  // other than the id and color.
+  const LogicalBuffer* DummyLogicalBuffer() {
+    const LogicalBuffer::Id id = buffers_.size();
+    buffers_.emplace_back(MakeUnique<LogicalBuffer>(nullptr, ShapeIndex{}, id,
+                                                    LogicalBuffer::Color(0)));
+    return buffers_.back().get();
+  }
+
+  std::vector<std::unique_ptr<LogicalBuffer>> buffers_;
+};
+
+class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(NoFragmentationStatsHeapTest, Empty) {
+  NoFragmentationStatsHeap heap;
+  EXPECT_EQ(0, heap.Finish().heap_size);
+}
+
+TEST_F(NoFragmentationStatsHeapTest, Simple) {
+  NoFragmentationStatsHeap heap;
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 20);
+  heap.Alloc(buffer_c_, 30);
+  heap.Alloc(buffer_d_, 30);
+  heap.Free(buffer_a_, 10);
+  heap.Free(buffer_b_, 20);
+  heap.Free(buffer_c_, 30);
+  heap.Free(buffer_d_, 30);
+  EXPECT_EQ(90, heap.Finish().heap_size);
+}
+
+TEST_F(NoFragmentationStatsHeapTest, Mixed) {
+  NoFragmentationStatsHeap heap;
+  heap.Alloc(buffer_a_, 10);  // max: A
+
+  heap.Alloc(buffer_b_, 20);  // max: A+B
+  heap.Free(buffer_b_, 20);
+
+  heap.Alloc(buffer_c_, 30);  // max: A+C
+  heap.Free(buffer_c_, 30);
+
+  heap.Alloc(buffer_d_, 5);  // max: A+C
+  heap.Free(buffer_d_, 5);
+
+  heap.Free(buffer_a_, 10);
+  EXPECT_EQ(40, heap.Finish().heap_size);
+}
+
+class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(DecreasingSizeRunsHeapTest, Empty) {
+  CallSequence call_sequence;
+  DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+  heap.Finish();
+  EXPECT_EQ(call_sequence, CallSequence({
+                               {kFinish, nullptr},
+                           }));
+}
+
+TEST_F(DecreasingSizeRunsHeapTest, Simple) {
+  CallSequence call_sequence;
+  DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 20);
+  heap.Alloc(buffer_c_, 30);
+  heap.Alloc(buffer_d_, 30);
+  heap.Free(buffer_a_, 10);
+  heap.Free(buffer_b_, 20);
+  heap.Free(buffer_c_, 30);
+  heap.Free(buffer_d_, 30);
+  heap.Finish();
+  // Runs of Allocs and Frees are sorted by decreasing size, with buffer id
+  // tiebreaker.
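+  // Here the sizes are a=10, b=20, c=30, d=30, so both the Alloc run and the
+  // Free run come out in the order c, d, b, a.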
+  EXPECT_EQ(call_sequence, CallSequence({
+                               {kAlloc, buffer_c_},
+                               {kAlloc, buffer_d_},
+                               {kAlloc, buffer_b_},
+                               {kAlloc, buffer_a_},
+                               {kFree, buffer_c_},
+                               {kFree, buffer_d_},
+                               {kFree, buffer_b_},
+                               {kFree, buffer_a_},
+                               {kFinish, nullptr},
+                           }));
+}
+
+TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
+  CallSequence call_sequence;
+  DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 20);
+  heap.Free(buffer_b_, 20);
+
+  heap.Alloc(buffer_c_, 30);
+  heap.Free(buffer_c_, 30);
+
+  heap.Alloc(buffer_d_, 5);
+  heap.Free(buffer_d_, 5);
+  heap.Free(buffer_a_, 10);
+  heap.Finish();
+  // Runs of Allocs and Frees are sorted by decreasing size.
+  EXPECT_EQ(call_sequence, CallSequence({
+                               {kAlloc, buffer_b_},
+                               {kAlloc, buffer_a_},
+                               {kFree, buffer_b_},
+
+                               {kAlloc, buffer_c_},
+                               {kFree, buffer_c_},
+
+                               {kAlloc, buffer_d_},
+                               {kFree, buffer_a_},
+                               {kFree, buffer_d_},
+                               {kFinish, nullptr},
+                           }));
+}
+
+class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(LazyBestFitHeapTest, Empty) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(0, result.heap_size);
+  EXPECT_EQ(0, result.chunk_map.size());
+}
+
+TEST_F(LazyBestFitHeapTest, Simple) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 20);
+  heap.Alloc(buffer_c_, 30);
+  heap.Alloc(buffer_d_, 30);
+  heap.Free(buffer_a_, 10);
+  heap.Free(buffer_b_, 20);
+  heap.Free(buffer_c_, 30);
+  heap.Free(buffer_d_, 30);
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(90, result.heap_size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
+
+  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(LazyBestFitHeapTest, Mixed) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+  heap.Alloc(buffer_a_, 10);  // A lazy offset
+
+  heap.Alloc(buffer_b_, 20);  // B lazy offset
+  heap.Free(buffer_b_, 20);   // B range = [0, 20)   free = [0, 20)
+
+  heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
+  heap.Free(buffer_c_, 30);   // free = [0, 30)
+
+  heap.Alloc(buffer_d_, 5);   // D range = [0, 5)    free = [5, 30)
+  heap.Free(buffer_d_, 5);    // free = [0, 30)
+
+  heap.Free(buffer_a_, 10);   // A range = [30, 40)  free = [0, 40)
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(40, result.heap_size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);
+
+  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(LazyBestFitHeapTest, BestFit) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+
+  // First alloc/free buffer_a_, to force a big free chunk to appear.
+  heap.Alloc(buffer_a_, 200);  // A lazy offset
+  heap.Free(buffer_a_, 200);   // A range = [0, 200)  free = [0, 200)
+
+  // Now alloc a bunch of buffers that are allocated out of the free chunk.
+  heap.Alloc(buffer_b_, 30);  // B range = [0, 30)    free = [30, 200)
+  heap.Alloc(buffer_c_, 30);  // C range = [30, 60)   free = [60, 200)
+  heap.Alloc(buffer_d_, 20);  // D range = [60, 80)   free = [80, 200)
+  heap.Alloc(buffer_e_, 20);  // E range = [80, 100)  free = [100, 200)
+  heap.Alloc(buffer_f_, 10);  // F range = [100, 110) free = [110, 200)
+  heap.Alloc(buffer_g_, 10);  // G range = [110, 120) free = [120, 200)
+  heap.Alloc(buffer_h_, 80);  // H range = [120, 200)
+
+  // Free buffers to create free chunks of different sizes.
+  heap.Free(buffer_c_, 30);  // free = [30, 60)
+  heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
+  heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)
+
+  // The best fit is picked out of the existing free chunks.
+  heap.Alloc(buffer_i_, 15);  // I range = [80, 95)
+
+  // The frees here ensure the buffer-coalescing logic is exercised.
+  heap.Free(buffer_b_, 30);
+  heap.Free(buffer_d_, 20);
+  heap.Free(buffer_f_, 10);
+  heap.Free(buffer_h_, 80);
+  heap.Free(buffer_i_, 15);
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(200, result.heap_size);
+  EXPECT_EQ(200, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_e_).size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_f_).size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_g_).size);
+  EXPECT_EQ(80, result.chunk_map.at(buffer_h_).size);
+  EXPECT_EQ(15, result.chunk_map.at(buffer_i_).size);
+
+  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
+  EXPECT_EQ(80, result.chunk_map.at(buffer_e_).offset);
+  EXPECT_EQ(100, result.chunk_map.at(buffer_f_).offset);
+  EXPECT_EQ(110, result.chunk_map.at(buffer_g_).offset);
+  EXPECT_EQ(120, result.chunk_map.at(buffer_h_).offset);
+  EXPECT_EQ(80, result.chunk_map.at(buffer_i_).offset);
+}
+
+TEST_F(LazyBestFitHeapTest, Lazy) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+
+  // First alloc some buffers, which are all lazily allocated offsets.
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 5);
+  heap.Alloc(buffer_c_, 10);
+
+  // Now free some buffers, which forces offset assignment.
+  heap.Free(buffer_a_, 10);  // A range = [0, 10)   free = [0, 10)
+  heap.Free(buffer_c_, 10);  // C range = [10, 20)  free = [0, 20)
+
+  // If we hadn't lazily assigned offsets, the free chunk wouldn't be large
+  // enough to hold the entire allocation.
+  heap.Alloc(buffer_d_, 20);  // D range = [0, 20)
+
+  heap.Free(buffer_b_, 5);  // B range = [20, 25)
+  heap.Free(buffer_d_, 20);
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(25, result.heap_size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
+
+  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
+  LazyBestFitHeap heap(/*alignment=*/1);
+
+  // First alloc/free buffer_a_, to force a big free chunk to appear.
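+  // (Freeing A lazily assigns it the range [0, 60) and leaves the whole heap
+  // free again, so the allocations below exercise chunk splitting rather
+  // than heap growth.)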
+  heap.Alloc(buffer_a_, 60);  // A lazy offset
+  heap.Free(buffer_a_, 60);   // A range = [0, 60)  free = [0, 60)
+
+  // Now alloc a bunch of buffers that are allocated out of the free chunk.
+  heap.Alloc(buffer_b_, 10);  // B range = [0, 10)   free = [10, 60)
+  heap.Alloc(buffer_c_, 20);  // C range = [10, 30)  free = [30, 60)
+  heap.Alloc(buffer_d_, 30);  // D range = [30, 60)
+
+  // Free buffers to create free chunks of different sizes.
+  heap.Free(buffer_b_, 10);  // free = [0, 10)
+  heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)
+
+  // No free chunks are large enough, but the last free chunk is adjacent to
+  // the end of the heap, so we re-use that chunk.
+  heap.Alloc(buffer_e_, 40);  // E range = [30, 70)
+
+  heap.Free(buffer_c_, 20);
+  heap.Free(buffer_e_, 40);
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(70, result.heap_size);
+  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
+  EXPECT_EQ(40, result.chunk_map.at(buffer_e_).size);
+
+  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).offset);
+  EXPECT_EQ(30, result.chunk_map.at(buffer_e_).offset);
+}
+
+TEST_F(LazyBestFitHeapTest, Alignment) {
+  LazyBestFitHeap heap(/*alignment=*/64);
+
+  // First alloc some buffers, which are all lazily allocated offsets.
+  heap.Alloc(buffer_a_, 10);
+  heap.Alloc(buffer_b_, 5);
+  heap.Alloc(buffer_c_, 10);
+
+  // Now free some buffers, which forces offset assignment with alignment.
+  heap.Free(buffer_a_, 10);  // A range = [0, 10)    free = [0, 10)
+  heap.Free(buffer_c_, 10);  // C range = [64, 74)   free = [0, 74)
+
+  // If we hadn't lazily assigned offsets, and accounted for alignment, the
+  // free chunk wouldn't be large enough to hold the entire allocation.
+  heap.Alloc(buffer_d_, 74);  // D range = [0, 74)    free = [)
+
+  heap.Free(buffer_b_, 5);    // B range = [128, 133) free = [74, 133)
+  heap.Alloc(buffer_e_, 23);  // E range = [128, 151) free = [74, 128)
+
+  heap.Free(buffer_d_, 74);  // free = [0, 128)
+  heap.Free(buffer_e_, 23);  // free = [0, 151)
+
+  const HeapSimulator::Result result = heap.Finish();
+  EXPECT_EQ(151, result.heap_size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
+  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
+  EXPECT_EQ(74, result.chunk_map.at(buffer_d_).size);
+  EXPECT_EQ(23, result.chunk_map.at(buffer_e_).size);
+
+  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
+  EXPECT_EQ(128, result.chunk_map.at(buffer_b_).offset);
+  EXPECT_EQ(64, result.chunk_map.at(buffer_c_).offset);
+  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+  EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
new file mode 100644
index 00000000000..af853385d63
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING.
+//
+// Don't use these protos in the real compilation or execution codepaths. The
+// data format is meant for debugging only, and may change without notice.
+//
+// Many of the protos below are simple 1-to-1 serializations of the
+// corresponding C++ classes.
+//
+// FIELD NAMES ARE IMPORTANT
+//
+// Unlike most protos, you can't safely change the names of fields, even if
+// you keep the numeric ids the same. This is because we sometimes serialize
+// these protos as JSON, which includes the field names in the serialization.
+
+syntax = "proto3";
+
+package xla;
+import "tensorflow/compiler/xla/xla_data.proto";
+
+option cc_enable_arenas = true;
+
+// Serialization of HloInstruction.
+message HloInstructionProto {
+  string name = 1;
+  string opcode = 2;
+  xla.Shape shape = 3;
+  repeated string operand_names = 4;
+  repeated string control_predecessor_names = 5;
+  repeated string called_computation_names = 6;
+
+  xla.OpMetadata metadata = 7;
+
+  // Literal, only present for kConstant.
+  xla.LiteralProto literal = 8;
+
+  // Parameter info, only present for kParameter.
+  int64 parameter_number = 9;
+  string parameter_name = 10;
+
+  // Fusion state, only present for kFusion.
+  string fusion_kind = 11;
+  HloComputationProto fused_instructions_computation = 12;
+
+  // Index for kGetTupleElement.
+  int64 tuple_index = 13;
+}
+
+// Serialization of HloComputation.
+message HloComputationProto {
+  string name = 1;
+
+  // The array of instructions is always in a valid dependency order, where
+  // operands appear before their users.
+  repeated HloInstructionProto instructions = 2;
+}
+
+// Serialization of HloModule.
+message HloModuleProto {
+  string name = 1;
+  string entry_computation_name = 2;
+
+  // The array of computations is always in a valid dependency order, where
+  // callees appear before their callers.
+  repeated HloComputationProto computations = 3;
+}
+
+// Serialization of HloOrdering.
+message HloOrderingProto {
+  // NOTE: currently only sequential orderings are serialized.
+  message SequentialComputation {
+    string computation_name = 1;
+    repeated string instruction_names = 2;
+  }
+  repeated SequentialComputation sequential_computations = 1;
+}
+
+// Serialization of LogicalBuffer.
+message LogicalBufferProto {
+  // Location represents an instruction and its shape index, which uniquely
+  // identifies a point where a buffer is needed.
+  message Location {
+    // NOTE: module_name isn't necessary, since all LogicalBuffers are
+    // associated with a single HloModule.
+    string computation_name = 1;
+    string instruction_name = 2;
+    repeated int64 shape_index = 3;
+  }
+
+  int64 id = 1;
+  int64 size = 2;
+
+  // The location where the buffer is defined.
+  Location defined_at = 3;
+
+  int64 color = 4;
+}
+
+// Serialization of BufferAllocation.
+message BufferAllocationProto {
+  // Assigned represents a single LogicalBuffer that is assigned to this
+  // BufferAllocation.
+ message Assigned { + int64 logical_buffer_id = 1; + int64 offset = 2; + int64 size = 3; + } + + int64 index = 1; + int64 size = 2; + bool is_thread_local = 3; + bool is_reusable = 4; + bool is_entry_computation_parameter = 5; + int64 parameter_number = 6; + bool maybe_live_out = 7; + int64 color = 8; + repeated Assigned assigned = 9; +} + +// A trace of a HeapSimulator run. +message HeapSimulatorTrace { + // The trace includes a list of events, where each event describes one action + // performed by the heap simulator. + message Event { + enum Kind { + ALLOC = 0; // A memory region was allocated for the buffer. + FREE = 1; // A memory region was freed for the buffer. + + // A buffer was shared with another (canonical) buffer. This is similar to + // ALLOC, except that instead of allocating a new region of memory, the + // memory region of the canonical buffer is directly re-used. Multiple + // buffers may share with the same canonical buffer. The lifetime of the + // canonical buffer is extended to the union of all lifetimes. + SHARE_WITH = 2; + } + Kind kind = 1; + + // The id of the LogicalBuffer that the event applies to. + int64 buffer_id = 2; + + // The HloInstruction that the simulation was processing that caused this + // event to occur, identified by its computation and instruction name. E.g. + // buffers defined by instruction A are allocated when processing A. + string computation_name = 3; + string instruction_name = 4; + + // The id of the canonical LogicalBuffer that the buffer shares with. Only + // set for SHARE_WITH events. + int64 share_with_canonical_id = 5; + } + repeated Event events = 1; + bool whole_module_simulation = 2; +} + +// Serialization of BufferAssignment. +message BufferAssignmentProto { + // Alias represents a source LogicalBuffer, and the buffer location that + // aliases it. + message BufferAlias { + int64 source_buffer_id = 1; + LogicalBufferProto.Location location = 2; + } + + repeated LogicalBufferProto logical_buffers = 1; + repeated BufferAlias buffer_aliases = 2; + repeated BufferAllocationProto buffer_allocations = 3; + repeated HeapSimulatorTrace heap_simulator_traces = 4; +} + +// Grouping message that contains all of the information above. +message HloProto { + HloModuleProto hlo_module = 1; + HloOrderingProto hlo_ordering = 2; + BufferAssignmentProto buffer_assignment = 3; +} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc new file mode 100644 index 00000000000..3b37f4a4b89 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -0,0 +1,396 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +void HloBuffer::AddValue(const HloValue& value) { + // If the value is already contained in this buffer, just return. + if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) != + value_ids_.end()) { + return; + } + + value_ids_.push_back(value.id()); + + // Add all of the locations of the HloValue to this buffer. + for (const HloLocation& location : value.locations()) { + if (std::find(locations_.begin(), locations_.end(), location) == + locations_.end()) { + locations_.push_back(location); + } + } +} + +bool HloBuffer::operator==(const HloBuffer& other) const { + bool equal = id() == other.id(); + if (equal) { + // DCHECK because these comparisons are expensive (linear time). + DCHECK(value_ids() == other.value_ids()); + DCHECK(locations() == other.locations()); + } + return equal; +} + +string HloBuffer::ToString() const { + return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", ")); +} + +std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { + out << buffer.ToString(); + return out; +} + +void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) { + if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) == + buffer_ids_.end()) { + buffer_ids_.push_back(buffer_id); + } +} + +void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) { + auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id); + CHECK(it != buffer_ids_.end()); + buffer_ids_.erase(it); +} + +string HloBufferSet::ToString() const { + return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", ")); +} + +std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +bool InstructionBufferSet::IsAmbiguous() const { + bool is_ambiguous = false; + ForEachElement( + [&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) { + is_ambiguous |= buffer_set.buffer_ids().size() > 1; + }); + return is_ambiguous; +} + +bool InstructionBufferSet::IsDistinct() const { + bool is_distinct = true; + tensorflow::gtl::FlatSet seen_ids; + ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index, + const HloBufferSet& buffer_set) { + for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { + auto pair = seen_ids.insert(buffer_id); + if (!pair.second) { + is_distinct = false; + } + } + }); + return is_distinct; +} + +string InstructionBufferSet::ToString() const { + string out = + StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloBufferSet& value_set) { + StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), 
"\n"); + }); + return out; +} + +std::ostream& operator<<(std::ostream& out, + const InstructionBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {} + +void HloAliasAnalysis::InitializeBufferSets() { + std::unordered_map value_to_buffer; + + // Initially define a buffer for every HloValue in the module. + for (const HloValue* value : dataflow_analysis_->values()) { + HloBuffer& buffer = NewHloBuffer(); + buffer.AddValue(*value); + value_to_buffer[value->id()] = buffer.id(); + } + + // Construct the Instruction buffer set to contain the HloBuffers for each + // HloValue in the InstructionValueSet. + for (auto& computation : module_->computations()) { + for (auto& instruction : computation->instructions()) { + buffer_sets_.emplace(std::piecewise_construct, + std::forward_as_tuple(instruction.get()), + std::forward_as_tuple(instruction->shape())); + dataflow_analysis_->GetInstructionValueSet(instruction.get()) + .ForEachElement( + [this, &instruction, &value_to_buffer]( + const ShapeIndex& index, const HloValueSet& value_set) { + for (HloValue::Id value_id : value_set.value_ids()) { + HloBuffer::Id buffer_id = value_to_buffer.at(value_id); + GetBufferSet(instruction.get(), index).AddBuffer(buffer_id); + } + }); + } + } +} + +void HloAliasAnalysis::CombineBuffers( + tensorflow::gtl::ArraySlice buffer_ids) { + VLOG(4) << "Combining buffers: " << Join(buffer_ids, ", "); + + if (buffer_ids.size() < 2) { + return; + } + + // Merging buffers invalidates the buffer vector. + buffers_vector_.clear(); + + // Add all values from all buffers to the first buffer in the list. + HloBuffer& unified_buffer = GetBuffer(buffer_ids[0]); + for (int i = 1; i < buffer_ids.size(); ++i) { + const HloBuffer::Id buffer_id = buffer_ids[i]; + const HloBuffer& buffer = GetBuffer(buffer_id); + + VLOG(4) << "Eliminating buffer: " << buffer_id; + + // Add all values held by the buffer-to-eliminate to the unified buffer. + for (HloValue::Id value_id : buffer.value_ids()) { + unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id)); + } + + // Iterate through all locations where the buffer-to-eliminate exists and + // replace it with the unified buffer. + for (const HloLocation& location : buffer.locations()) { + VLOG(4) << "Replacing in " << location; + GetBufferSet(location.instruction, location.index) + .RemoveBufferOrDie(buffer_id); + GetBufferSet(location.instruction, location.index) + .AddBuffer(unified_buffer.id()); + } + + buffers_.erase(buffer_id); + } + + TF_DCHECK_OK(Verify()); +} + +Status HloAliasAnalysis::Verify() const { + // Verify every HloBuffer in buffers_ exists somewhere in an HloBufferSet and + // verify that every HloBuffer in the HloBufferSets exists somewhere in + // buffers_. 
+ tensorflow::gtl::FlatSet buffers_in_sets; + for (auto& pair : buffer_sets_) { + const InstructionBufferSet& instruction_buffer_set = pair.second; + TF_RETURN_IF_ERROR(instruction_buffer_set.ForEachElementWithStatus( + [this, &buffers_in_sets](const ShapeIndex& index, + const HloBufferSet& buffer_set) -> Status { + for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { + TF_RET_CHECK(ContainsKey(buffers_, buffer_id)); + buffers_in_sets.insert(buffer_id); + } + return Status::OK(); + })); + } + for (auto& pair : buffers_) { + const HloBuffer::Id buffer_id = pair.first; + const HloBuffer& buffer = pair.second; + TF_RET_CHECK(buffer_id == buffer.id()); + TF_RET_CHECK(ContainsKey(buffers_in_sets, buffer_id)); + } + return Status::OK(); +} + +void HloAliasAnalysis::FlattenInstructionBufferSets( + tensorflow::gtl::ArraySlice instructions) { + VLOG(4) << "Flattening buffer sets of instructions: " + << Join(instructions, ", ", + [this](string* out, const HloInstruction* instruction) { + StrAppend(out, instruction->FullyQualifiedName()); + }); + if (instructions.size() < 2) { + return; + } + ShapeUtil::ForEachSubshape( + instructions[0]->shape(), + [this, instructions](const Shape& /*subshape*/, const ShapeIndex& index) { + // Gather all HloBuffers contained in all the buffer sets of the + // given instructions at the current index. + std::vector to_unify; + for (const HloInstruction* instruction : instructions) { + const HloBufferSet& buffer_set = GetBufferSet(instruction, index); + to_unify.insert(to_unify.end(), buffer_set.buffer_ids().begin(), + buffer_set.buffer_ids().end()); + } + // Sort and uniquify buffers to combine. + std::sort(to_unify.begin(), to_unify.end()); + to_unify.erase(std::unique(to_unify.begin(), to_unify.end()), + to_unify.end()); + + CombineBuffers(to_unify); + }); +} + +HloBuffer& HloAliasAnalysis::NewHloBuffer() { + HloBuffer::Id buffer_id = next_buffer_id_++; + auto it_added = buffers_.emplace(std::piecewise_construct, + std::forward_as_tuple(buffer_id), + std::forward_as_tuple(buffer_id)); + CHECK(it_added.second); + + return it_added.first->second; +} + +string HloAliasAnalysis::ToString() const { + string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + StrAppend(&out, " Instruction buffer sets:\n"); + for (const std::unique_ptr& computation : + module_->computations()) { + for (const std::unique_ptr& instruction : + computation->instructions()) { + StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); + auto buffer_str = [this](const HloBuffer& buffer) { + return StrCat( + "Buffer ", buffer.id(), ", values: ", + Join(buffer.value_ids(), ", ", + [this](string* out, HloValue::Id value_id) { + StrAppend( + out, + dataflow_analysis_->GetValue(value_id).ToShortString()); + })); + }; + if (ShapeUtil::IsTuple(instruction->shape())) { + GetInstructionBufferSet(instruction.get()) + .ForEachElement([this, &out, &buffer_str]( + const ShapeIndex& index, + const HloBufferSet& buffer_set) { + StrAppend(&out, " tuple index ", index.ToString(), ":\n"); + for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { + StrAppend(&out, " ", buffer_str(GetBuffer(buffer_id)), + "\n"); + } + }); + } else { + const HloBufferSet top_level_buffer_set = + GetBufferSet(instruction.get()); + for (HloBuffer::Id buffer_id : top_level_buffer_set.buffer_ids()) { + StrAppend(&out, " ", buffer_str(GetBuffer(buffer_id)), "\n"); + } + } + } + } + return out; +} + +const InstructionBufferSet& HloAliasAnalysis::GetInstructionBufferSet( + const HloInstruction* instruction) 
    const {
+  return buffer_sets_.at(instruction);
+}
+
+InstructionBufferSet& HloAliasAnalysis::GetInstructionBufferSet(
+    const HloInstruction* instruction) {
+  return buffer_sets_.at(instruction);
+}
+
+const HloBufferSet& HloAliasAnalysis::GetBufferSet(
+    const HloInstruction* instruction, const ShapeIndex& index) const {
+  return buffer_sets_.at(instruction).element(index);
+}
+
+HloBufferSet& HloAliasAnalysis::GetBufferSet(const HloInstruction* instruction,
+                                             const ShapeIndex& index) {
+  return *buffer_sets_.at(instruction).mutable_element(index);
+}
+
+const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
+  if (buffers_vector_.empty()) {
+    // Lazily construct vector of buffers.
+    buffers_vector_.reserve(buffers_.size());
+    for (auto& pair : buffers_) {
+      buffers_vector_.push_back(&pair.second);
+    }
+    std::sort(buffers_vector_.begin(), buffers_vector_.end(),
+              [](const HloBuffer* a, const HloBuffer* b) {
+                return a->id() < b->id();
+              });
+  } else {
+    CHECK_EQ(buffers_vector_.size(), buffers_.size());
+    for (const HloBuffer* buffer : buffers_vector_) {
+      DCHECK(ContainsKey(buffers_, buffer->id()));
+      DCHECK(&GetBuffer(buffer->id()) == buffer);
+    }
+  }
+  return buffers_vector_;
+}
+
+/* static */
+StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
+    HloModule* module) {
+  VLOG(1) << "HloAliasAnalysis::Run on module " << module->name();
+  XLA_VLOG_LINES(2, module->ToString());
+
+  auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
+  TF_ASSIGN_OR_RETURN(
+      alias_analysis->dataflow_analysis_,
+      HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
+                               /*bitcast_defines_value=*/false));
+
+  alias_analysis->InitializeBufferSets();
+  VLOG(3) << "Initial state:\n" << alias_analysis->ToString();
+
+  // The while instruction updates its state in place, so the inputs to the
+  // while alias the while instruction, the parameters of the subcomputations,
+  // and the root of the body subcomputation.
+  for (auto& computation : module->computations()) {
+    for (auto& instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kWhile) {
+        VLOG(4) << "Flattening buffer sets at kWhile instruction: "
+                << instruction->name();
+        alias_analysis->FlattenInstructionBufferSets(
+            {instruction->operand(0),
+             instruction->while_body()->parameter_instruction(0),
+             instruction->while_body()->root_instruction(),
+             instruction->while_condition()->parameter_instruction(0),
+             instruction.get()});
+      }
+    }
+  }
+  VLOG(1) << alias_analysis->ToString();
+  return std::move(alias_analysis);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
new file mode 100644
index 00000000000..0fa35827b5e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -0,0 +1,301 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A container which can hold one or more HloValues. An HLO buffer abstractly
+// represents the allocation which HLO instructions write into and read
+// from. Generally there is a one-to-one correspondence between HloBuffers and
+// HloValues, where each HloValue in the module is held in a unique HloBuffer.
+// An exception is the while instruction, which updates the loop state
+// in-place. In this case, we have a single HloBuffer for each HloLocation in
+// the loop state, but multiple HloValues. For example:
+//
+//   %init = ...
+//   %while = While(%init, body, condition)
+//
+//   body:
+//     %body_param = Param(0)
+//     ...
+//     %body_root = ...
+//
+//   condition:
+//     %cond_param = Param(0)
+//     ...
+//
+// For simplicity, assume that %while is array-shaped. In this case, we have a
+// single HloBuffer which holds the following HloValues: HloValue{%init},
+// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
+// HloValue{%cond_param}.
+//
+// HloBuffers may appear at different HloLocations in the module, mirroring the
+// same property of HloValues. For example:
+//
+//   %sub = Sub(...)
+//   %add = Add(...)
+//   %tuple = Tuple(%add, %sub)
+//   %gte = GetTupleElement(%tuple, 0)
+//
+// In this case, the HloBuffer containing %add appears at the following
+// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and
+// HloLocation{%gte, {}}.
+//
+// Different HloLocations which share the same HloBuffer indicate mandatory
+// aliasing in the HLO module. These locations must share the same memory
+// allocation for correctness (the backends rely on this property). This
+// differs from incidental aliasing introduced by memory reuse in
+// BufferAssignment, where different instructions may happen to get the same
+// allocation.
+class HloBuffer {
+ public:
+  using Id = int64;
+
+  HloBuffer(int64 id) : id_(id) {}
+
+  // Return the unique identifier for this HloBuffer.
+  int64 id() const { return id_; }
+
+  // Add a value to the set of values held by this buffer. Also adds the
+  // HloLocations of the value to the locations vector of the buffer. If the
+  // buffer already contains this value, then this method is a nop.
+  void AddValue(const HloValue& value);
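+  //
+  // For example, adding a value whose locations are HloLocation{%add, {}} and
+  // HloLocation{%tuple, {0}} extends locations() with both entries; locations
+  // this buffer already holds are skipped.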
+
+  // Return the IDs of all values contained in this buffer.
+  const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
+
+  // Return the locations (output of which instruction and at what index) where
+  // the buffer is used. This is exactly the union of the locations of the
+  // HloValues contained by the buffer.
+  const std::vector<HloLocation>& locations() const { return locations_; }
+
+  string ToString() const;
+
+  bool operator==(const HloBuffer& other) const;
+  bool operator!=(const HloBuffer& other) const { return !(*this == other); }
+
+ private:
+  // Unique identifier for this HloBuffer.
+  const Id id_;
+
+  // The set of values contained in this buffer.
+  std::vector<HloValue::Id> value_ids_;
+
+  // The set of locations where this buffer is used.
+  std::vector<HloLocation> locations_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
+
+// A class representing the set of possible HloBuffers at a particular
+// HloLocation (shape index in the output of an instruction) in the XLA
+// graph. In most cases, the buffer set will have a single HloBuffer indicating
+// that the HloBuffer which appears at that particular location is known
+// unambiguously at compile-time. However, tuple-shaped Select instructions can
+// introduce ambiguity as the tuple elements of the operands are passed by
+// reference into the output of the Select. For example:
+//
+//   %pred = ...
+//   %tuple0 = Tuple(%a, %b)
+//   %tuple1 = Tuple(%x, %y)
+//   %select = Select(%pred, %tuple0, %tuple1)
+//
+// In this case the HloBufferSet at HloLocation{%select, {0}} contains the
+// HloBuffer holding %a and the HloBuffer holding %x.
+class HloBufferSet {
+ public:
+  HloBufferSet() = default;
+
+  // Add the given buffer to this buffer set. If the buffer already exists in
+  // the set, then this is a NOP.
+  void AddBuffer(HloBuffer::Id buffer_id);
+
+  // Removes the given buffer from this buffer set. CHECK fails if the buffer
+  // is not contained in this set.
+  void RemoveBufferOrDie(HloBuffer::Id buffer_id);
+
+  // Returns the unique buffer in this set. CHECK fails if the set does not
+  // contain exactly one buffer.
+  HloBuffer::Id GetUniqueBufferId() const {
+    CHECK_EQ(buffer_ids().size(), 1);
+    return buffer_ids()[0];
+  }
+
+  // Returns the IDs of the HloBuffers contained in this buffer set.
+  const std::vector<HloBuffer::Id>& buffer_ids() const { return buffer_ids_; }
+
+  string ToString() const;
+
+ private:
+  // The IDs of the HloBuffers contained in this buffer set.
+  std::vector<HloBuffer::Id> buffer_ids_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set);
+
+// A class collecting the HloBuffers in the output of an HLO instruction. For
+// array-shaped instructions, an InstructionBufferSet trivially holds a single
+// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple
+// HloBufferSets.
+class InstructionBufferSet : public ShapeTree<HloBufferSet> {
+ public:
+  InstructionBufferSet(const Shape& shape) : ShapeTree<HloBufferSet>(shape) {}
+
+  // Returns true if any HloBufferSet contained in this InstructionBufferSet
+  // is not a singleton.
+  bool IsAmbiguous() const;
+
+  // Returns true if no HloBuffer appears in more than one HloBufferSet
+  // contained in this InstructionBufferSet (i.e., the buffer sets are
+  // disjoint).
+  bool IsDistinct() const;
+
+  string ToString() const;
+};
+
+std::ostream& operator<<(std::ostream& out,
+                         const InstructionBufferSet& buffer_set);
+
+class HloAliasAnalysis {
+ public:
+  static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(HloModule* module);
+
+  string ToString() const;
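+
+  // Illustrative use of this class (the instruction and module names here are
+  // hypothetical; only Run, GetUniqueBufferAt and locations are real):
+  //
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> analysis,
+  //                       HloAliasAnalysis::Run(module));
+  //   const HloBuffer& buffer = analysis->GetUniqueBufferAt(instr);
+  //   for (const HloLocation& location : buffer.locations()) {
+  //     // All of these locations must be backed by one allocation.
+  //   }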
+
+  // Return the InstructionBufferSet for the given instruction.
+  const InstructionBufferSet& GetInstructionBufferSet(
+      const HloInstruction* instruction) const;
+  InstructionBufferSet& GetInstructionBufferSet(
+      const HloInstruction* instruction);
+
+  // Return the HloBufferSet for the given location.
+  const HloBufferSet& GetBufferSet(const HloInstruction* instruction,
+                                   const ShapeIndex& index = {}) const;
+  HloBufferSet& GetBufferSet(const HloInstruction* instruction,
+                             const ShapeIndex& index = {});
+
+  // Return the HloBuffer with the given ID.
+  const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const {
+    return buffers_.at(buffer_id);
+  }
+  HloBuffer& GetBuffer(HloBuffer::Id buffer_id) {
+    return buffers_.at(buffer_id);
+  }
+
+  // Returns the unique buffer at the given location. CHECK fails if the buffer
+  // set at that location does not contain exactly one buffer.
+  const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
+                                     const ShapeIndex& index = {}) const {
+    return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId());
+  }
+  HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
+                               const ShapeIndex& index = {}) {
+    return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId());
+  }
+
+  // Return a vector of all HloBuffers stably sorted by HloBuffer::Id. This
+  // vector is lazily computed. Mutating operations on HloAliasAnalysis may
+  // invalidate the underlying vector, requiring recomputation.
+  const std::vector<const HloBuffer*>& buffers() const;
+
+  // Returns the underlying dataflow analysis used by this alias analysis.
+  const HloDataflowAnalysis& dataflow_analysis() const {
+    return *dataflow_analysis_;
+  }
+
+ protected:
+  HloAliasAnalysis(HloModule* module);
+
+  // Creates a new HloBuffer and returns a reference to it.
+  HloBuffer& NewHloBuffer();
+
+  // Construct the initial set of buffer sets where an HloBuffer is created for
+  // each HloValue in the module.
+  void InitializeBufferSets();
+
+  // Combine the InstructionBufferSets for the given instructions. The
+  // HloBuffers in the HloBufferSets at each ShapeIndex are combined via
+  // CombineBuffers into a single HloBuffer. This single HloBuffer then becomes
+  // the only member of these HloBufferSets (i.e., they become singletons). The
+  // HloBuffers which are removed from the buffer sets are deleted from the
+  // analysis. This flattening may change InstructionBufferSets of other
+  // instructions not in 'instructions' because the HloBuffers of the
+  // InstructionBufferSets of 'instructions' can be used elsewhere in the
+  // module.
+  //
+  // This method is used to enforce the mandatory aliasing of while
+  // instructions, where the init operand, body parameter, condition parameter,
+  // body root instruction, and the while itself must have exactly the same
+  // HloBuffer at each ShapeIndex.
+  //
+  // Precondition: The shapes of the given instructions must be compatible.
+  void FlattenInstructionBufferSets(
+      tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
+
+  // Combines the given HloBuffers into a single buffer. One of the given
+  // HloBuffers is chosen as the unified buffer, and all other references to
+  // the remaining buffers are replaced by this unified buffer. All HloValues
+  // contained in the replaced buffers are moved to the unified buffer, and the
+  // replaced buffers are deleted from the analysis.
+  void CombineBuffers(tensorflow::gtl::ArraySlice<HloBuffer::Id> buffer_ids);
+
+  // Verifies internal state of the analysis.
+  Status Verify() const;
+
+  HloModule* module_;
+
+  // The underlying dataflow analysis used by this alias analysis.
+  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
+
+  // The map of all HloBuffers in the module.
+  std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;
+
+  // A map from instruction to its InstructionBufferSet.
+  std::unordered_map<const HloInstruction*, InstructionBufferSet> buffer_sets_;
+
+  // A lazily constructed vector containing all HloBuffers sorted by
+  // HloBuffer::Id.
+  mutable std::vector<const HloBuffer*> buffers_vector_;
+
+  // The Id to use for the next HloBuffer.
+  int64 next_buffer_id_ = 0;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
new file mode 100644
index 00000000000..24c467d411b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -0,0 +1,760 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ::testing::UnorderedElementsAre;
+
+class HloAliasAnalysisTest : public HloTestBase {
+ protected:
+  HloAliasAnalysisTest() : module_(TestName()) {}
+
+  // Run alias analysis on the member module. For convenience returns a
+  // reference to the generated analysis stored in analysis_.
+  const HloAliasAnalysis& RunAnalysis() {
+    analysis_ = HloAliasAnalysis::Run(&module_).ConsumeValueOrDie();
+    return *analysis_;
+  }
+
+  // Return a vector of the buffers in the buffer set at the given location.
+  std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
+                                      const ShapeIndex& index = {}) const {
+    std::vector<HloBuffer> buffers;
+    for (HloBuffer::Id buffer_id :
+         analysis_->GetBufferSet(instruction, index).buffer_ids()) {
+      buffers.push_back(analysis_->GetBuffer(buffer_id));
+    }
+    return buffers;
+  }
+
+  // Return a vector containing all of the HloValues in the given buffer.
+  std::vector<HloValue> GetValuesInBuffer(const HloBuffer& buffer) {
+    std::vector<HloValue> values;
+    for (HloValue::Id value_id : buffer.value_ids()) {
+      values.push_back(analysis_->dataflow_analysis().GetValue(value_id));
+    }
+    return values;
+  }
+
+  // Return the HloValue defined at the given location.
+ const HloValue& GetValueDefinedAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index); + } + + const HloValue& GetUniqueValueInBuffer(const HloBuffer& buffer) const { + CHECK_EQ(buffer.value_ids().size(), 1); + return analysis_->dataflow_analysis().GetValue(buffer.value_ids()[0]); + } + + HloModule module_; + std::unique_ptr analysis_; + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(HloAliasAnalysisTest, BinaryOperation) { + // Test the analysis on a single binary operation (Add). + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, constant1, constant2)); + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.buffers().size(), 3); + + // All of the buffer sets should trivially contain a single buffer containing + // a single value. + for (const HloInstruction* instruction : {constant1, constant2, add}) { + EXPECT_EQ(GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(instruction)), + GetValueDefinedAt(instruction)); + } + + EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct()); +} + +TEST_F(HloAliasAnalysisTest, TupleAndGtes) { + // Verify the analysis for a Tuple and GetTupleElement instructions. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.buffers().size(), 4); + + // Verify the expected aliasing of the tuple elements. + EXPECT_EQ( + GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{})), + GetValueDefinedAt(tuple, /*index=*/{})); + EXPECT_EQ( + GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{0})), + GetValueDefinedAt(param0)); + EXPECT_EQ( + GetUniqueValueInBuffer(analysis.GetUniqueBufferAt(tuple, /*index=*/{1})), + GetValueDefinedAt(param1)); + + // The tuple operand, tuple element, and result of the GTE instruction should + // all be the same buffer. + EXPECT_EQ(analysis.GetUniqueBufferAt(param0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(param0), + analysis.GetUniqueBufferAt(gte0)); + + // Verify the locations of an aliased buffer. 
+  EXPECT_THAT(
+      analysis.GetUniqueBufferAt(param0).locations(),
+      UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
+                           HloLocation{gte0, {}}));
+
+  EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
+  EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
+}
+
+TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
+  // Test an expression with a non-distinct buffer set.
+  auto builder = HloComputation::Builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  // param0 is included twice in the tuple.
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({param0, param1, param0}));
+  module_.AddEntryComputation(builder.Build());
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  EXPECT_THAT(
+      analysis.GetUniqueBufferAt(param0).locations(),
+      UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
+                           HloLocation{tuple, {2}}));
+
+  EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
+  EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
+}
+
+TEST_F(HloAliasAnalysisTest, SingleCall) {
+  // Test a single call of a subcomputation. The subcomputation adds its two
+  // array-shaped parameters.
+  auto subbuilder = HloComputation::Builder("Subcomputation");
+  auto subparam0 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto subparam1 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
+  HloComputation* called_computation =
+      module_.AddEmbeddedComputation(subbuilder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto call = builder.AddInstruction(HloInstruction::CreateCall(
+      scalar_shape_, {constant1, constant2}, called_computation));
+  module_.AddEntryComputation(builder.Build());
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  // Verify aliasing of the kCall operands and the subcomputation parameters.
+  EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
+              UnorderedElementsAre(HloLocation{constant1, {}},
+                                   HloLocation{subparam0, {}}));
+  EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
+              UnorderedElementsAre(HloLocation{constant2, {}},
+                                   HloLocation{subparam1, {}}));
+
+  // The subcomputation root and the kCall itself should alias.
+  EXPECT_THAT(
+      analysis.GetUniqueBufferAt(add).locations(),
+      UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}}));
+}
+
+TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
+  // Test a subcomputation which is called twice with different argument
+  // values.
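+  //
+  // HLO (sketch):
+  //
+  // subcomputation((F32[] %param0, F32[] %param1)):
+  //   return Add(%param0, %param1)
+  //
+  // entry:
+  //   %call1 = Call({%constant1, %constant2}, subcomputation)
+  //   return Call({%call1, %constant2}, subcomputation)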
+  auto subbuilder = HloComputation::Builder("Subcomputation");
+  auto subparam0 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto subparam1 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
+  HloComputation* called_computation =
+      module_.AddEmbeddedComputation(subbuilder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
+      scalar_shape_, {constant1, constant2}, called_computation));
+  auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
+      scalar_shape_, {call1, constant2}, called_computation));
+  module_.AddEntryComputation(builder.Build());
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
+              UnorderedElementsAre(HloLocation{constant1, {}},
+                                   HloLocation{subparam0, {}}));
+  EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
+              UnorderedElementsAre(HloLocation{constant2, {}},
+                                   HloLocation{subparam1, {}}));
+
+  // The 'add' (root of the subcomputation) aliases the two call instructions
+  // and the first parameter of the subcomputation, because 'call1' is passed
+  // as an argument to the subcomputation in 'call2'.
+  EXPECT_THAT(
+      analysis.GetUniqueBufferAt(add).locations(),
+      UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call1, {}},
+                           HloLocation{subparam0, {}}, HloLocation{call2, {}}));
+
+  EXPECT_THAT(GetBuffersAt(subparam0),
+              UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
+                                   analysis.GetUniqueBufferAt(add)));
+  EXPECT_THAT(GetBuffersAt(subparam1),
+              UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2)));
+
+  EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsAmbiguous());
+  EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous());
+  EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct());
+  EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct());
+}
+
+TEST_F(HloAliasAnalysisTest, SingleWhile) {
+  // Test a simple single while instruction. The while body includes a
+  // pass-through value. HLO:
+  //
+  // body((F32[], F32[]) %tuple_param):
+  //   %add = Add(%tuple_param{0}, %tuple_param{1})
+  //   return Tuple(%tuple_param{0}, %add)
+  //
+  // condition((F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %constant1 = Constant(1.0)
+  //   %constant2 = Constant(2.0)
+  //   %tuple = Tuple(%constant1, %constant2)
+  //   return While(%tuple, body, condition)
+  //
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  // Element 0 passes transparently through the body.
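+  //
+  // Alias analysis should group %constant1, %tuple{0}, %while{0},
+  // %body_param{0}, %body_element_0 and %cond_param{0} into one HloBuffer, and
+  // the corresponding element-1 values (including %add) into another, as the
+  // EXPECTs below verify.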
+ auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Verify the locations of the aliased while buffers. + EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).locations(), + UnorderedElementsAre( + HloLocation{tuple, {}}, HloLocation{xla_while, {}}, + HloLocation{body_param, {}}, HloLocation{body_tuple, {}}, + HloLocation{cond_param, {}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).locations(), + UnorderedElementsAre( + HloLocation{constant1, {}}, HloLocation{tuple, {0}}, + HloLocation{xla_while, {0}}, HloLocation{body_param, {0}}, + HloLocation{body_element_0, {}}, HloLocation{body_tuple, {0}}, + HloLocation{cond_param, {0}})); + EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).locations(), + UnorderedElementsAre( + HloLocation{constant2, {}}, HloLocation{tuple, {1}}, + HloLocation{xla_while, {1}}, HloLocation{body_param, {1}}, + HloLocation{body_element_1, {}}, HloLocation{add, {}}, + HloLocation{body_tuple, {1}}, HloLocation{cond_param, {1}})); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), + UnorderedElementsAre(GetValueDefinedAt(constant1))); + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(constant2), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add))); +} + +TEST_F(HloAliasAnalysisTest, SequentialWhiles) { + // Test sequential while instructions. The while body includes a + // pass-through value. 
HLO: + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple = Tuple(%constant1, %constant2) + // %while0 = While(%tuple, body, condition) + // %while1 = While(%while0, body, condition) + // return While(%while1, body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while0 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); + auto xla_while1 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); + auto xla_while2 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}), + analysis.GetUniqueBufferAt(xla_while2, /*index=*/{})); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(xla_while2, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), + analysis.GetUniqueBufferAt(xla_while2, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, NestedWhiles) { + // Test nested while instructions. The inner body passes through element 0 of + // its parameter, and the outer body passes through element 1. 
HLO: + // + // inner_body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // outer_body((F32[], F32[]) %tuple_param): + // %negate = Negate(%tuple_param{0}) + // %tuple = Tuple(%negate, %tuple_param{1}) + // return While(%tuple, inner_body, condition) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple = Tuple(%constant1, %constant2) + // return While(%tuple, outer_body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + // Element 0 passes transparently through the body. + auto inner_builder = HloComputation::Builder("inner_body"); + auto inner_param = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto inner_element_0 = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0)); + auto inner_element_1 = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1)); + auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1)); + inner_builder.AddInstruction( + HloInstruction::CreateTuple({inner_element_0, add})); + HloComputation* inner_body = + module_.AddEmbeddedComputation(inner_builder.Build()); + + // Element 1 passes transparently through the body. 
+  auto outer_builder = HloComputation::Builder("outer_body");
+  auto outer_param = outer_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto outer_element_0 = outer_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
+  auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
+      scalar_shape_, HloOpcode::kNegate, outer_element_0));
+  auto outer_element_1 = outer_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
+  auto outer_tuple = outer_builder.AddInstruction(
+      HloInstruction::CreateTuple({negate, outer_element_1}));
+  auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, condition, inner_body, outer_tuple));
+  HloComputation* outer_body =
+      module_.AddEmbeddedComputation(outer_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({constant1, constant2}));
+  auto entry_while = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
+  module_.AddEntryComputation(builder.Build());
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
+            analysis.GetUniqueBufferAt(entry_while, /*index=*/{0}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
+            analysis.GetUniqueBufferAt(nested_while, /*index=*/{0}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
+            analysis.GetUniqueBufferAt(inner_element_0));
+
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
+            analysis.GetUniqueBufferAt(entry_while, /*index=*/{1}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
+            analysis.GetUniqueBufferAt(nested_while, /*index=*/{1}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
+            analysis.GetUniqueBufferAt(inner_element_1));
+}
+
+TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
+  // Test a while instruction with a body which permutes its tuple parameter
+  // elements. HLO:
+  //
+  // body((F32[], F32[], F32[]) %tuple_param):
+  //   return Tuple(%tuple_param{1}, %tuple_param{2}, %tuple_param{0})
+  //
+  // condition((F32[], F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %constant1 = Constant(1.0)
+  //   %constant2 = Constant(2.0)
+  //   %constant3 = Constant(3.0)
+  //   %tuple = Tuple(%constant1, %constant2, %constant3)
+  //   return While(%tuple, body, condition)
+  //
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
+
+  auto body_builder = HloComputation::Builder("body");
+  auto body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto body_element_0 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+  auto body_element_1 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+  auto body_element_2 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2));
+  body_builder.AddInstruction(HloInstruction::CreateTuple(
+      {body_element_1, body_element_2, body_element_0}));
+  HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+
+  auto cond_builder = HloComputation::Builder("condition");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto cond_constant = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_.AddEmbeddedComputation(cond_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto constant3 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({constant1, constant2, constant3}));
+  auto xla_while = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
+  module_.AddEntryComputation(builder.Build());
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  // The swizzling while makes most locations in the module alias, leaving
+  // only three HloBuffers.
+  EXPECT_THAT(
+      analysis.buffers(),
+      UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
+                           &analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
+                           &analysis.GetUniqueBufferAt(cond_constant)));
+
+  // The tuple elements of the while and the three constant inputs should all
+  // be merged into the same buffer.
+  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
+            analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
+            analysis.GetUniqueBufferAt(xla_while, /*index=*/{2}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
+            analysis.GetUniqueBufferAt(constant1));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
+            analysis.GetUniqueBufferAt(constant2));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
+            analysis.GetUniqueBufferAt(constant3));
+}
+
+TEST_F(HloAliasAnalysisTest, TupleSelect) {
+  // Test a kSelect of a tuple value. Non-top-level elements flow through the
+  // instruction.
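+  //
+  // HLO (sketch):
+  //
+  // entry:
+  //   %tuple1 = Tuple(%constant1) ... %tuple4 = Tuple(%constant4)
+  //   %select12 = Select(%pred, %tuple1, %tuple2)
+  //   %select34 = Select(%pred, %tuple3, %tuple4)
+  //   return Select(%pred, %select12, %select34)
+  //
+  // The buffer set at index {0} of each select is the union of the {0} buffer
+  // sets of its operands.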
+ auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + auto constant4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + auto tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({constant1})); + auto tuple2 = + builder.AddInstruction(HloInstruction::CreateTuple({constant2})); + auto tuple3 = + builder.AddInstruction(HloInstruction::CreateTuple({constant3})); + auto tuple4 = + builder.AddInstruction(HloInstruction::CreateTuple({constant4})); + const Shape tuple_shape = tuple1->shape(); + auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Verify the buffer sets of each select. + EXPECT_THAT(analysis.GetBufferSet(select11, /*index=*/{0}).buffer_ids(), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id())); + EXPECT_THAT(analysis.GetBufferSet(select12, /*index=*/{0}).buffer_ids(), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id(), + analysis.GetUniqueBufferAt(constant2).id())); + EXPECT_THAT(analysis.GetBufferSet(select34, /*index=*/{0}).buffer_ids(), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3).id(), + analysis.GetUniqueBufferAt(constant4).id())); + EXPECT_THAT(analysis.GetBufferSet(select1234, /*index=*/{0}).buffer_ids(), + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1).id(), + analysis.GetUniqueBufferAt(constant2).id(), + analysis.GetUniqueBufferAt(constant3).id(), + analysis.GetUniqueBufferAt(constant4).id())); + + EXPECT_FALSE(analysis.GetInstructionBufferSet(select11).IsAmbiguous()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsAmbiguous()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsAmbiguous()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsAmbiguous()); + + EXPECT_TRUE(analysis.GetInstructionBufferSet(select11).IsDistinct()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct()); +} + +TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { + // Test a tuple-shaped kSelect feeding a kWhile instruction. 
HLO: + // + // body((F32[], F32[]) %tuple_param): + // %negate = Negate(%tuple_param{0}) + // return Tuple(%negate) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple1 = Tuple(%constant1) + // %tuple2 = Tuple(%constant2) + // %select = Select(%tuple1, %tuple2) + // return While(%select, body, condition) + // + auto builder = HloComputation::Builder(TestName()); + + const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, body_element)); + body_builder.AddInstruction(HloInstruction::CreateTuple({negate})); + HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({constant1})); + auto tuple2 = + builder.AddInstruction(HloInstruction::CreateTuple({constant2})); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, select)); + + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // The while should flatten the ambiguous select buffer set so that the buffer + // set contents (constant1 and constant2) becomes a single buffer. + EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(constant2)); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})); + + EXPECT_THAT(GetValuesInBuffer(analysis.GetUniqueBufferAt(constant1)), + UnorderedElementsAre(GetValueDefinedAt(constant1), + GetValueDefinedAt(constant2), + GetValueDefinedAt(xla_while, /*index=*/{0}), + GetValueDefinedAt(body_param, /*index=*/{0}), + GetValueDefinedAt(cond_param, /*index=*/{0}), + GetValueDefinedAt(negate))); + EXPECT_FALSE(analysis.GetInstructionBufferSet(select).IsAmbiguous()); + EXPECT_FALSE(analysis.GetInstructionBufferSet(xla_while).IsAmbiguous()); + + EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct()); + EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct()); +} + +TEST_F(HloAliasAnalysisTest, Bitcast) { + // Bitcasting a value should not produce a new buffer. 
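+  // This relies on HloAliasAnalysis::Run constructing its dataflow analysis
+  // with bitcast_defines_value=false, so the bitcast below reuses the value
+  // (and hence the buffer) of its operand.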
+ auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kBitcast, constant)); + + module_.AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.buffers().size(), 1); + + EXPECT_EQ(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(bitcast)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c55f489494e..ff76cc7bf67 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -35,10 +35,14 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::strings::StrCat; + std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { int parameter_count = 0; @@ -52,16 +56,17 @@ std::unique_ptr HloComputation::Builder::Build( root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique( - new HloComputation(name_, parameter_count, &instructions_, root)); + return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, + root, is_fusion_computation_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction) + HloInstruction* root_instruction, bool is_fusion_computation) : name_(name), root_instruction_(root_instruction), + is_fusion_computation_(is_fusion_computation), instruction_name_uniquer_(/*separator=*/".") { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; @@ -90,19 +95,85 @@ HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { // Generate a unique name for the instruction. 
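
The replacement below routes naming through the instruction itself, but the underlying mechanism is still `NameUniquer`. A hedged sketch of the behavior being relied on here, inferred only from its use in this file (the exact collision-suffix format is an assumption):

```c++
// NameUniquer hands back the base name on first request and a
// separator-disambiguated variant on collisions.
NameUniquer uniquer(/*separator=*/".");
std::string first = uniquer.GetUniqueName("add");   // "add"
std::string second = uniquer.GetUniqueName("add");  // e.g. "add.1"
```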
-  instruction->set_name(
-      instruction_name_uniquer_.GetUniqueName(instruction->name()));
-  instruction->set_parent(this);
+  instruction->UniquifyName(&instruction_name_uniquer_);
+  Reparent(instruction.get());
   HloInstruction* pinst = instruction.get();
   instruction_iterators_[pinst] =
       instructions_.insert(instructions_.end(), std::move(instruction));
   return pinst;
 }
 
-/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
-  return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
-           opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||
-           opcode == HloOpcode::kOutfeed);
+HloInstruction* HloComputation::AddParameter(
+    std::unique_ptr<HloInstruction> instruction) {
+  CHECK(instruction->opcode() == HloOpcode::kParameter);
+  CHECK(is_fusion_computation_);
+  CHECK(root_instruction_->fusion_instruction() != nullptr);
+  instruction->SetParentFusion(root_instruction_->fusion_instruction());
+  CHECK(root_instruction_->fusion_instruction()->operand_count() ==
+        param_instructions_.size());
+  instruction->set_parent(this);
+  param_instructions_.push_back(instruction.get());
+  AddInstructionInternal(std::move(instruction));
+  return instructions_.back().get();
+}
+
+Status HloComputation::RemoveParameter(int64 param_no) {
+  CHECK_GE(param_no, 0);
+  CHECK_LT(param_no, param_instructions_.size());
+  CHECK(is_fusion_computation_);
+  CHECK(root_instruction_->fusion_instruction() != nullptr);
+  HloInstruction* param_instruction = param_instructions_[param_no];
+  auto param_instruction_iterator = param_instructions_.begin() + param_no;
+  param_instructions_.erase(param_instruction_iterator);
+  // Throw the removed fused parameter instruction away.
+  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+
+  while (param_no < param_instructions_.size()) {
+    param_instruction = param_instructions_[param_no];
+    string param_name = param_instruction->parameter_name();
+    // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+    // renumbering the parameters, so replace the final number in the name
+    // with the updated value.
+    const string param_underscore = ".param_";
+    size_t index = param_name.rfind(param_underscore);
+    if (index != string::npos) {
+      string after_param = param_name.substr(index + param_underscore.size());
+      int64 numeric_suffix;
+      if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+        param_name =
+            StrCat(param_name.substr(0, index), param_underscore, param_no);
+      }
+    }
+
+    HloInstruction* new_instr =
+        AddInstructionInternal(HloInstruction::CreateParameter(
+            param_no, param_instruction->shape(), param_name));
+    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+    new_instr->SetParentFusion(root_instruction_->fusion_instruction());
+    param_instructions_[param_no] = new_instr;
+    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+    param_no++;
+  }
+
+  return Status::OK();
+}
+
+void HloComputation::Reparent(HloInstruction* instruction) {
+  instruction->set_parent(this);
+}
+
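The string surgery in `RemoveParameter` is compact enough to miss. Here is a self-contained sketch of just the renaming convention (an illustrative stand-alone function, not part of the patch):

```c++
#include <cstdio>
#include <string>

// Fusion parameters are named "<base>.param_<N>"; on renumbering, only the
// trailing number changes. Names that do not follow the convention are
// returned unchanged.
std::string RenumberParamName(const std::string& param_name, int param_no) {
  const std::string param_underscore = ".param_";
  const size_t index = param_name.rfind(param_underscore);
  if (index == std::string::npos) {
    return param_name;
  }
  return param_name.substr(0, index) + param_underscore +
         std::to_string(param_no);
}

int main() {
  // Prints "fusion.param_1": the old suffix 3 is replaced by the new index.
  std::printf("%s\n", RenumberParamName("fusion.param_3", 1).c_str());
  return 0;
}
```

+bool HloComputation::IsRemovable(const HloInstruction* instruction) {
+  // If the instruction has control predecessors or successors then we cannot
+  // remove the instruction without violating ordering constraints (added, for
+  // example, to avert interference due to buffer aliasing).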
+  if (!instruction->control_predecessors().empty() ||
+      !instruction->control_successors().empty()) {
+    return false;
+  }
+  const HloOpcode opcode = instruction->opcode();
+  return !((opcode == HloOpcode::kParameter && !is_fusion_computation_) ||
+           opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend ||
+           opcode == HloOpcode::kTrace || opcode == HloOpcode::kOutfeed);
 }
 
 Status HloComputation::RemoveInstructionAndUnusedOperands(
@@ -110,51 +181,49 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
   TF_RET_CHECK(root_instruction() != instruction);
   TF_RET_CHECK(instruction->user_count() == 0);
-  TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode()));
-  std::queue<HloInstruction*> remove;
-  remove.push(instruction);
-  while (!remove.empty()) {
-    HloInstruction* item = remove.front();
-    remove.pop();
-    if (item->user_count() != 0 || item == root_instruction_ ||
-        !HloComputation::IsRemovable(item->opcode())) {
+  TF_RET_CHECK(IsRemovable(instruction));
+  std::unordered_set<HloInstruction*> removed;
+  std::queue<HloInstruction*> worklist;
+  worklist.push(instruction);
+  while (!worklist.empty()) {
+    HloInstruction* item = worklist.front();
+    worklist.pop();
+
+    if (removed.count(item) != 0 || item->user_count() != 0 ||
+        item == root_instruction() || !IsRemovable(item)) {
       continue;
     }
     for (int i = 0; i < item->operand_count(); ++i) {
-      remove.push(item->mutable_operand(i));
+      worklist.push(item->mutable_operand(i));
     }
-    // If an instruction has the same operand more than once, we must not remove
-    // it again.
     TF_RETURN_IF_ERROR(RemoveInstruction(item));
+    removed.insert(item);
   }
   return Status::OK();
 }
 
-StatusOr<bool> HloComputation::RemoveInstructionIfFound(
-    HloInstruction* instruction) {
-  TF_RET_CHECK(IsRemovable(instruction->opcode()));
-  TF_RET_CHECK(root_instruction() != instruction)
-      << "cannot remove root instruction";
-  TF_RET_CHECK(instruction->user_count() == 0)
-      << "instruction with users cannot be removed";
-
-  if (instruction_iterators_.count(instruction) == 0) {
-    return false;
-  }
+Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
   VLOG(2) << "Removing instruction " << instruction->name()
           << " from computation " << name();
+  TF_RET_CHECK(IsRemovable(instruction));
+  TF_RET_CHECK(root_instruction() != instruction)
+      << "cannot remove root instruction " << instruction->name();
+  TF_RET_CHECK(instruction->user_count() == 0)
+      << "instruction " << instruction->name()
+      << " has users and cannot be removed";
+  TF_RET_CHECK(instruction->control_predecessors().empty())
+      << "instruction " << instruction->name()
+      << " has control predecessors and cannot be removed";
+  TF_RET_CHECK(instruction->control_successors().empty())
+      << "instruction " << instruction->name()
+      << " has control successors and cannot be removed";
+
+  TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
   auto inst_it = instruction_iterators_.at(instruction);
   (*inst_it)->set_parent(nullptr);
   instruction->DetachFromOperands();
   instructions_.erase(inst_it);
-  return true;
-}
-
-Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
-  TF_ASSIGN_OR_RETURN(bool removed, RemoveInstructionIfFound(instruction));
-  TF_RET_CHECK(removed) << instruction->ToString()
-                        << " is not a member of computation " << name();
   return Status::OK();
 }
 
@@ -234,14 +303,14 @@ void ComputeComputationPostOrder(
   }
 
   for (auto& instruction : computation->instructions()) {
-    for (auto& called_computation : instruction->MakeCalledComputationsSet()) {
+    for (HloComputation* called_computation :
+         instruction->called_computations()) {
      ComputeComputationPostOrder(called_computation, visited, post_order);
     }
   }
 
   visited->insert(computation);
   post_order->push_back(computation);
-  return;
 }
 
 }  // namespace
 
@@ -286,22 +355,41 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
   return post_order;
 }
 
-string HloComputation::ToString() const {
+string HloComputation::ToString(int nested_level) const {
   std::ostringstream s;
+  for (int i = 0; i < nested_level; i++) {
+    s << "  ";
+  }
   s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
     << " { \n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+    for (int i = 0; i < nested_level; i++) {
+      s << "  ";
+    }
     s << "  " << instruction->ToString() << "\n";
     if (instruction->opcode() == HloOpcode::kFusion) {
-      for (const auto& fused_instruction : instruction->fused_instructions()) {
-        s << "    " << fused_instruction->ToString() << "\n";
-      }
+      s << instruction->fused_instructions_computation()->ToString(
+               nested_level + 1)
+        << "\n";
     }
   }
+  for (int i = 0; i < nested_level; i++) {
+    s << "  ";
+  }
   s << "}";
   return s.str();
 }
 
+HloComputationProto HloComputation::ToProto() const {
+  HloComputationProto proto;
+  proto.set_name(name_);
+  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+    HloInstructionProto instruction_proto = instruction->ToProto();
+    proto.add_instructions()->Swap(&instruction_proto);
+  }
+  return proto;
+}
+
 void HloComputation::FuseInstructionsInto(
     tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
     HloInstruction* fusion_instruction) {
@@ -390,15 +478,6 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
   }
 }
 
-Status HloComputation::AddControlDependency(HloInstruction* predecessor,
-                                            HloInstruction* successor) {
-  TF_RET_CHECK(instruction_iterators_.count(predecessor) > 0);
-  TF_RET_CHECK(instruction_iterators_.count(successor) > 0);
-  successor->AddControlPredecessor(predecessor);
-  predecessor->AddControlSuccessor(successor);
-  return Status::OK();
-}
-
 ProgramShape HloComputation::ComputeProgramShape() const {
   ProgramShape program_shape;
 
@@ -419,7 +498,9 @@ bool HloComputation::operator==(const HloComputation& other) const {
     // If <a,b> are visited but not identical, the recursion should have
     // been aborted. So, if <a,b> are visited at this point, they must be
     // identical.
-    if (visited.count(std::make_pair(a, b)) > 0) return true;
+    if (visited.count(std::make_pair(a, b)) > 0) {
+      return true;
+    }
     visited.emplace(a, b);
     return a->Identical(
         *b, eq, [](const HloComputation* a, const HloComputation* b) {
@@ -442,6 +523,15 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                   new_instruction->shape()));
   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
            << new_instruction->ToString();
+  // Try to add metadata for HLO instructions that are created to replace
+  // existing HLO instructions (e.g. during optimizations). The assumption is
+  // that the old instruction and the new instruction would perform the same
+  // function, and that they would be correlated to the same TF op. This might
+  // not always be correct since HLO optimizations can cross TF op boundaries.
+  // Still, this seems better than nothing.
+  if (new_instruction->metadata().op_name().empty()) {
+    new_instruction->set_metadata(old_instruction->metadata());
+  }
   TF_RETURN_IF_ERROR(
       ReplaceUsesOfInstruction(old_instruction, new_instruction));
   return RemoveInstructionAndUnusedOperands(old_instruction);
@@ -510,21 +600,46 @@ HloComputation::ComputeTransitiveOperands() const {
   return result;
 }
 
-Status HloComputation::Accept(DfsHloVisitor* visitor) const {
-  // Visit all dead roots.
+std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
+  std::vector<HloInstruction*> unreachable_roots;
   for (auto& instruction : instructions()) {
     if (instruction->user_count() == 0 &&
         instruction->control_successors().empty() &&
         instruction.get() != root_instruction()) {
-      // Call FinishVisit only at the end.
-      TF_RETURN_IF_ERROR(
-          instruction->Accept(visitor, /*call_finish_visit=*/false));
+      unreachable_roots.push_back(instruction.get());
     }
   }
-  // Visit root instruction last.
+  return unreachable_roots;
+}
+
+Status HloComputation::Accept(DfsHloVisitor* visitor) const {
+  // Visit unreachable roots. Beware that the visitor might delete the
+  // currently visited root, which would invalidate iterators if the
+  // unreachable roots weren't computed ahead of time.
+  for (HloInstruction* root : CollectUnreachableRoots()) {
+    // Call FinishVisit only at the end.
+    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
+  }
+  // Visit the computation root instruction last.
   return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
 }
 
+Status HloComputation::AcceptWithOperandOrder(
+    DfsHloVisitor* visitor,
+    const HloInstruction::CompareFunction& operand_order) const {
+  // Visit unreachable roots. Beware that the visitor might delete the
+  // currently visited root, which would invalidate iterators if the
+  // unreachable roots weren't computed ahead of time.
+  for (HloInstruction* root : CollectUnreachableRoots()) {
+    TF_RETURN_IF_ERROR(
+        root->AcceptWithOperandOrder(visitor, operand_order,
+                                     /*call_finish_visit=*/false));
+  }
+  // Visit the computation root instruction last.
+  return root_instruction()->AcceptWithOperandOrder(
+      visitor, operand_order, /*call_finish_visit=*/true);
+}
+
 Status HloComputation::AcceptOrdered(
     DfsHloVisitor* visitor,
     const std::vector<const HloInstruction*>& order) const {
@@ -555,4 +670,44 @@ Status HloComputation::Accept(
   return this->Accept(&visitor);
 }
 
+std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) {
+  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
+  auto postorder = MakeInstructionPostOrder();
+  std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
+  std::vector<std::unique_ptr<HloInstruction>> instructions;
+  std::unique_ptr<HloInstruction> new_instr = nullptr;
+  for (auto instr : postorder) {
+    std::vector<HloInstruction*> new_operands;
+    for (auto operand : instr->operands()) {
+      HloInstruction* new_operand = FindOrDie(clone_map, operand);
+      CHECK(new_operand != nullptr);
+      new_operands.push_back(new_operand);
+    }
+
+    new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands);
+    InsertOrDie(&clone_map, instr, new_instr.get());
+    instructions.push_back(std::move(new_instr));
+  }
+  Builder builder(name() + suffix);
+  for (auto& instr : instructions) {
+    builder.AddInstruction(std::move(instr));
+  }
+  auto result = builder.Build(
+      /*root_instruction=*/FindOrDie(clone_map, root_instruction()));
+
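Control edges are copied in a second pass because, while walking in post order, an edge can point at a node that has not been cloned yet; only once the map is fully populated is every endpoint available. A minimal generic sketch of the same two-pass shape (illustrative only, not patch code):

```c++
#include <cstddef>
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<Node*> successors;
};

// Pass one: create a clone per node and record the mapping.
// Pass two: re-draw edges through the map, now that every target exists.
std::vector<Node> CloneGraph(const std::vector<Node*>& originals) {
  std::vector<Node> clones(originals.size());
  std::unordered_map<const Node*, Node*> clone_map;
  for (size_t i = 0; i < originals.size(); ++i) {
    clone_map[originals[i]] = &clones[i];
  }
  for (size_t i = 0; i < originals.size(); ++i) {
    for (Node* successor : originals[i]->successors) {
      clone_map[originals[i]]->successors.push_back(clone_map.at(successor));
    }
  }
  return clones;
}
```

+  // Clone control dependencies.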
+  for (auto instr : postorder) {
+    HloInstruction* new_instr = FindOrDie(clone_map, instr);
+    for (auto successor : instr->control_successors()) {
+      TF_CHECK_OK(
+          new_instr->AddControlDependencyTo(FindOrDie(clone_map, successor)));
+    }
+  }
+  return result;
+}
+
+void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
+  name_ = name_uniquer->GetUniqueName(name_);
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index e78e86b91fd..39074b24e41 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -54,8 +55,10 @@ class HloComputation {
   // Builder class for HloComputation.
   class Builder {
    public:
-    explicit Builder(const string& name)
-        : name_(name), last_added_instruction_(nullptr) {}
+    explicit Builder(const string& name, bool is_fusion_computation = false)
+        : name_(name),
+          last_added_instruction_(nullptr),
+          is_fusion_computation_(is_fusion_computation) {}
 
     // Build and return an HloComputation. The parameter root_instruction
     // specifies the already-added instruction to use as the root. If
@@ -74,6 +77,7 @@ class HloComputation {
    private:
     const string name_;
     HloInstruction* last_added_instruction_;
+    bool is_fusion_computation_;
     std::vector<std::unique_ptr<HloInstruction>> instructions_;
   };
 
@@ -81,6 +85,16 @@ class HloComputation {
   // the instruction.
   HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
 
+  // Remove the param_no'th parameter from the computation. Note that this is
+  // only applicable to the computation of a fusion instruction.
+  Status RemoveParameter(int64 param_no);
+
+  // Add a new parameter instruction to the computation. The instruction must
+  // be a freshly created parameter; it is appended to the parameter list and
+  // inserted into the instruction list.
+  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
+
   // Remove an instruction from the computation. The instruction must have no
   // users. Instruction is deallocated with this call.
   Status RemoveInstruction(HloInstruction* instruction);
@@ -111,7 +125,7 @@ class HloComputation {
   // Returns the parameter instruction for the given parameter number.
   HloInstruction* parameter_instruction(int64 param_no) const {
     CHECK_GE(param_no, 0);
-    CHECK_LT(param_no, param_instructions_.size());
+    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()));
     return param_instructions_[param_no];
   }
 
@@ -121,24 +135,20 @@ class HloComputation {
   const string& name() const { return name_; }
 
+  // Use the given NameUniquer to select a unique name for the computation
+  // based on the computation's existing name.
+  void UniquifyName(NameUniquer* name_uniquer);
+
   // Return a string representation of the computation.
-  string ToString() const;
+  string ToString(int nested_level = 0) const;
+
+  // Returns a serialized representation of this computation.
+  HloComputationProto ToProto() const;
 
   const std::list<std::unique_ptr<HloInstruction>>& instructions() const {
     return instructions_;
   }
 
-  // Add a control dependency between the two instructions in this computation
-  // so that the 'predecessor' is visited before the 'successor' during the DFS
-  // traversal of the computation. Returns an error status if either of the
-  // given instructions does not belong to the current computation.
-  //
-  // This is used to enforce an additional ordering requirement that is not
-  // captured by normal data dependencies, such as ordering among Send or Recv
-  // operations to avoid deadlock.
-  Status AddControlDependency(HloInstruction* predecessor,
-                              HloInstruction* successor);
-
   // Compute and return a post-order of the instructions in the computation. In
   // this order, definitions of values always appear before their uses.
   std::list<HloInstruction*> MakeInstructionPostOrder() const;
@@ -205,6 +215,7 @@ class HloComputation {
   // Set/get the module containing this computation.
   void set_parent(HloModule* module) { parent_ = module; }
   const HloModule* parent() const { return parent_; }
+  HloModule* parent() { return parent_; }
 
   // Visit every node in the computation in DFS post-order with the given
   // visitor. This is similar to calling HloInstruction::Accept on the root of
@@ -214,6 +225,13 @@ class HloComputation {
   // root instruction as the argument).
   Status Accept(DfsHloVisitor* visitor) const;
 
+  // Same as Accept() above, but the order of operand and control predecessor
+  // visitation is determined by the given operand order; if compare(A, B) ==
+  // true, A is visited before B.
+  Status AcceptWithOperandOrder(
+      DfsHloVisitor* visitor,
+      const HloInstruction::CompareFunction& operand_order) const;
+
   // Visit every node in the computation in the given order. 'order' must
   // be a topological sort of all instructions in the computation.
   Status AcceptOrdered(DfsHloVisitor* visitor,
@@ -222,26 +240,32 @@ class HloComputation {
   // Same as Accept() above, but the visitor is given as a function.
   Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const;
 
-  // Returns true if instructions of the given opcode can be removed from the
+  // Returns a deep copy of this computation including all instructions.
+  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone");
+
+  // Returns true if the given instruction can be removed from the
   // computation. Instructions such as parameters and send/receive instructions
   // cannot be removed without violating invariants of the HLO computation or
-  // module.
-  static bool IsRemovable(const HloOpcode& opcode);
+  // module, with the exception of fusion computations: a parameter instruction
+  // is removable from a fusion computation.
+  bool IsRemovable(const HloInstruction* instruction);
+
+  // Returns whether this computation is a fusion computation.
+  bool IsFusionComputation() const { return is_fusion_computation_; }
 
  private:
   explicit HloComputation(
       const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
-      HloInstruction* root_instruction);
+      HloInstruction* root_instruction, bool is_fusion_computation = false);
 
   // Internal helper for adding instructions.
   HloInstruction* AddInstructionInternal(
       std::unique_ptr<HloInstruction> instruction);
 
-  // Remove an instruction from the computation if found. The instruction must
-  // have no users. Instruction is deallocated with this call.
-  // Return whether instruction was found and removed.
- StatusOr RemoveInstructionIfFound(HloInstruction* instruction); + // Helper for setting the parent of instructions that are added to this + // computation. + void Reparent(HloInstruction* instruction); // Fuses HLOs in instructions_to_fuse into fusion_instruction. // @@ -254,9 +278,15 @@ class HloComputation { // of the given instruction. The given instruction must be tuple-shaped. StatusOr DeepCopyTuple(HloInstruction* instruction); - const string name_; + // Internal helper to collect unreachable roots. + std::vector CollectUnreachableRoots() const; + + string name_; HloInstruction* root_instruction_; + // A tag shows if this is a fusion computation. + bool is_fusion_computation_; + // Module containing this computation. HloModule* parent_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 1e0d09b72c7..5d49c83e2d0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,15 +20,22 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + class HloComputationTest : public HloTestBase { protected: HloComputationTest() {} @@ -67,8 +74,8 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { auto negate_computation = CreateNegateComputation(); auto map_computation = CreateMapComputation(negate_computation.get()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); - EXPECT_EQ(map_computation->MakeEmbeddedComputationsList().front(), - negate_computation.get()); + EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(), + ElementsAre(negate_computation.get())); } TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { @@ -93,10 +100,10 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // GetEmbeddedComputations returns a post order of the embedded computations, // so the negate computation must come first. 
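A note on this test file's migration from `EXPECT_MATCH` to gmock matchers: `UnorderedElementsAre` has set semantics, which is why the ordering claim in the comment above is pinned separately by the `EXPECT_EQ` that follows. A self-contained illustration (generic gmock usage, not from the patch):

```c++
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

TEST(MatcherDemo, OrderSensitivity) {
  const std::vector<int> values = {3, 1, 2};
  // Set semantics: passes for any permutation of the expected elements.
  EXPECT_THAT(values, ::testing::UnorderedElementsAre(1, 2, 3));
  // The order-sensitive counterpart would fail for {3, 1, 2}:
  //   EXPECT_THAT(values, ::testing::ElementsAre(1, 2, 3));
}
```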
EXPECT_EQ(negate_computation.get(), *embedded_computations.begin()); - EXPECT_MATCH(testing::ListToVec(embedded_computations), - testing::UnorderedMatcher( - negate_computation.get(), map1_computation.get(), - map2_computation.get())); + EXPECT_THAT( + embedded_computations, + UnorderedElementsAre(negate_computation.get(), map1_computation.get(), + map2_computation.get())); } TEST_F(HloComputationTest, PostOrderSingleton) { @@ -106,7 +113,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto computation = builder.Build(); - EXPECT_EQ(computation->MakeInstructionPostOrder().front(), constant); + EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } TEST_F(HloComputationTest, PostOrderSimple) { @@ -121,10 +128,8 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); auto computation = builder.Build(); - EXPECT_MATCH( - testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::OrderedMatcher(constant, negate1, negate2)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + ElementsAre(constant, negate1, negate2)); } TEST_F(HloComputationTest, PostOrderTrace) { @@ -141,10 +146,8 @@ TEST_F(HloComputationTest, PostOrderTrace) { auto computation = builder.Build(); // Trace instructions should be at the end of the sort. - EXPECT_MATCH(testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::OrderedMatcher(constant, negate1, - negate2, trace)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + ElementsAre(constant, negate1, negate2, trace)); } TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { @@ -161,10 +164,8 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto computation = builder.Build(); - EXPECT_MATCH(testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::UnorderedMatcher( - constant1, constant2, constant3, constant4)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + UnorderedElementsAre(constant1, constant2, constant3, constant4)); } TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { @@ -187,9 +188,8 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); - EXPECT_MATCH(testing::ListToVec(post_order), - testing::UnorderedMatcher( - constant1, constant2, constant3, add1, add2, add3)); + EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3, + add1, add2, add3)); } TEST_F(HloComputationTest, VisitWithMultipleRoots) { @@ -253,8 +253,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_EQ(HloOpcode::kCopy, copy->opcode()); - EXPECT_EQ(constant, copy->operand(0)); + EXPECT_THAT(copy, op::Copy(constant)); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -271,18 +270,10 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_EQ(HloOpcode::kTuple, tuple_copy->opcode()); - EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(0)->opcode()); - const HloInstruction* gte0 = tuple_copy->operand(0)->operand(0); - EXPECT_EQ(HloOpcode::kGetTupleElement, gte0->opcode()); - EXPECT_EQ(0, gte0->tuple_index()); - EXPECT_EQ(tuple, gte0->operand(0)); - - EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(1)->opcode()); - const 
HloInstruction* gte1 = tuple_copy->operand(1)->operand(0); - EXPECT_EQ(HloOpcode::kGetTupleElement, gte1->opcode()); - EXPECT_EQ(1, gte1->tuple_index()); - EXPECT_EQ(tuple, gte1->operand(0)); + EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), + op::Copy(op::GetTupleElement(tuple)))); + EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); + EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } TEST_F(HloComputationTest, CycleDetection) { @@ -297,15 +288,74 @@ TEST_F(HloComputationTest, CycleDetection) { auto computation = builder.Build(); // Add a control dependency to create a cycle. - ASSERT_IS_OK(computation->AddControlDependency(add, negate)); + ASSERT_IS_OK(add->AddControlDependencyTo(negate)); const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); - ASSERT_MATCH(visit_status.error_message(), - testing::ContainsRegex("cycle is detecte")); + ASSERT_THAT(visit_status.error_message(), + ::testing::ContainsRegex("cycle is detecte")); +} + +TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { + // Test RemoveInstructionAndUnusedOperands with an instruction which has a + // duplicated (dead) operand. This verifies that the operand is not deleted + // twice. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto dead_negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); + auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); + auto computation = builder.Build(); + + EXPECT_EQ(4, computation->instruction_count()); + EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_EQ(negate, computation->root_instruction()); + + ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); + + EXPECT_EQ(2, computation->instruction_count()); + EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_EQ(negate, computation->root_instruction()); +} + +TEST_F(HloComputationTest, CloneWithControlDependency) { + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + auto computation = builder.Build(/*root_instruction=*/add); + + TF_CHECK_OK(negate->AddControlDependencyTo(add)); + + auto clone = computation->Clone(); + + auto cloned_add = clone->root_instruction(); + EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd); + + auto predecessors = cloned_add->control_predecessors(); + EXPECT_EQ(1, predecessors.size()); + EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode()); + auto successors = predecessors[0]->control_successors(); + EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } } // namespace } // namespace xla + +int 
main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc new file mode 100644 index 00000000000..93f448e7018 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr HloConstantFolding::Run(HloModule* module) { + auto evaluator = MakeUnique(); + + XLA_VLOG_LINES(2, + "HloConstantFolding::Run(), before:\n" + module->ToString()); + bool changed = false; + + for (auto& computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + // Skip dead code. + if (instruction->user_count() == 0 && + computation->root_instruction() != instruction) { + continue; + } + // Skip Constant and Parameter operation. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant) { + continue; + } + // Skip instructions with non-constant operands. + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + continue; + } + + std::unique_ptr result = evaluator->TryEvaluate(instruction); + // Currently we skip unimplemented operations. + // TODO(b/35975797): Fold constant computations for more operations. + if (result == nullptr) { + VLOG(2) << "Constant folding failed for instruction: " + << instruction->ToString(); + continue; + } + + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateConstant(std::move(result)))); + changed = true; + } + } + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h new file mode 100644 index 00000000000..331480bd029 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which performs constant folding in order to avoid unnecessary +// computation on constants. +class HloConstantFolding : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "constant_folding"; } + + // Run constant folding operations on the given module. Returns whether the + // module was changed (constant expressions folded). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc new file mode 100644 index 00000000000..31b81052cb2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -0,0 +1,218 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using HloConstantFoldingTest = HloTestBase; + +TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42); +} + +TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42.0f); +} + +TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {0}), + 42); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {1}), + 19); +} + +TEST_F(HloConstantFoldingTest, Concatenate) { + const struct TestConfig { + int concat_dimension; + tensorflow::gtl::ArraySlice dimensions; + 
tensorflow::gtl::ArraySlice concat_sizes; + } test_configs[] = { + {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, + {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, + }; + + for (auto& test_config : test_configs) { + HloComputation::Builder builder(TestName()); + std::vector dimensions(test_config.dimensions.begin(), + test_config.dimensions.end()); + int64 concat_size = 0; + std::vector operands; + for (auto csize : test_config.concat_sizes) { + dimensions[test_config.concat_dimension] = csize; + concat_size += csize; + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + HloInstruction* insn = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + operands.push_back(insn); + } + dimensions[test_config.concat_dimension] = concat_size; + Shape shape = ShapeUtil::MakeShape(F32, dimensions); + builder.AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, test_config.concat_dimension)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + } +} + +TEST_F(HloConstantFoldingTest, Slice) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + const int64 slice_start[] = {4, 2, 3, 1, 5}; + const int64 slice_limits[] = {10, 8, 6, 5, 9}; + const int64 slice_strides[] = {1, 1, 1, 1, 1}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); + builder.AddInstruction(HloInstruction::CreateSlice( + shape, literal_instruction, slice_start, slice_limits, slice_strides)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); +} + +TEST_F(HloConstantFoldingTest, TransposeConstantFold) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; + builder.AddInstruction( + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); + + using NativeT = typename 
primitive_util::PrimitiveTypeToNative::type; + bool matched = true; + LiteralUtil::EachCell( + root->literal(), + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + matched = matched && (value == LiteralUtil::Get(*literal_clone, + rindexes)); + }); + EXPECT_TRUE(matched); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 2866f8158d5..38cc74b0f1e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -20,10 +20,43 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { +Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { + // Set current instruction cost values to reasonable default values. Each + // handler can overwrite these values. In Postprocess, these value are + // accumulated and written to the per-instruction maps. + current_flop_count_ = 0; + current_transcendental_count_ = 0; + + // The default element count for an instruction is the sum of elements in the + // operands and output. The default ShapeUtil::ByteSizeOf does not handle + // opaque types. + current_bytes_accessed_ = shape_size_(hlo->shape()); + for (const HloInstruction* operand : hlo->operands()) { + current_bytes_accessed_ += shape_size_(operand->shape()); + } + + return Status::OK(); +} + +Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { + // Accumulate cost values and write into per-instruction maps. + flop_count_ += current_flop_count_; + hlo_to_flop_count_[hlo] = current_flop_count_; + + transcendental_count_ += current_transcendental_count_; + hlo_to_transcendental_count_[hlo] = current_transcendental_count_; + + bytes_accessed_ += current_bytes_accessed_; + hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_; + + return Status::OK(); +} + Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { const auto& shape = hlo_instruction->shape(); // For element-wise operations, the number of computations is the same as the @@ -32,12 +65,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { auto opcode = hlo_instruction->opcode(); // We treat the two opcodes (kExp, kPower) as transcendental operations. if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) { - transcendental_count_ += computation_count; + current_transcendental_count_ = computation_count; } else { // Note: transcendental operations are considered a separate category from // FLOPs. 
- hlo_to_flop_count_[hlo_instruction] = computation_count; - flop_count_ += computation_count; + current_flop_count_ = computation_count; } return Status::OK(); } @@ -69,16 +101,21 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, } Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { + current_bytes_accessed_ = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(HloInstruction* constant, const Literal& literal) { + current_bytes_accessed_ = 0; return Status::OK(); } Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) { + // GetTupleElement forwards a pointer and does not touch each element in the + // output. + current_bytes_accessed_ = 0; return Status::OK(); } @@ -99,9 +136,9 @@ Status HloCostAnalysis::HandleSlice(HloInstruction* slice, return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice operands) { +Status HloCostAnalysis::HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) { return Status::OK(); } @@ -114,6 +151,10 @@ Status HloCostAnalysis::HandleDynamicUpdateSlice( Status HloCostAnalysis::HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice operands) { + // The tuple instruction only gathers pointers from inputs (it doesn't iterate + // through them). The memory touched is then only the size of the output + // buffer. + current_bytes_accessed_ = shape_size_(tuple->shape()); return Status::OK(); } @@ -125,8 +166,7 @@ Status HloCostAnalysis::HandleConcatenate( Status HloCostAnalysis::HandleConvert(HloInstruction* convert, HloInstruction* operand) { - flop_count_ += ShapeUtil::ElementsIn(operand->shape()); - return Status::OK(); + return HandleElementwiseOp(convert); } Status HloCostAnalysis::HandleCopy(HloInstruction* copy, @@ -137,15 +177,24 @@ Status HloCostAnalysis::HandleCopy(HloInstruction* copy, Status HloCostAnalysis::HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, HloInstruction* rhs_instruction) { + const Shape& lhs_shape = lhs_instruction->shape(); + const Shape& rhs_shape = rhs_instruction->shape(); + // Count of elements along the reduction dimension (last dimension for the + // rhs). + int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); + + // First divide by reduction width before multiplying by rhs elements to avoid + // overflow. + int64 fma_count; + if (reduction_width == 0) { + fma_count = 0; + } else { + fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) * + ShapeUtil::ElementsIn(rhs_shape); + } + // We count an FMA operation as 2 floating point operations. - // Multiplying the sizes of lhs, rhs, and result produces the square of the - // number of FMAs during the computation. - auto fma_count = std::sqrt( - static_cast(ShapeUtil::ElementsIn(lhs_instruction->shape())) * - ShapeUtil::ElementsIn(rhs_instruction->shape()) * - ShapeUtil::ElementsIn(dot->shape())); - flop_count_ += 2 * fma_count; - hlo_to_flop_count_[dot] = 2 * fma_count; + current_flop_count_ = kFmaFlops * fma_count; return Status::OK(); } @@ -163,15 +212,14 @@ Status HloCostAnalysis::HandleMap( tensorflow::gtl::ArraySlice /*static_operands*/) { // Compute the cost of the user function. 
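Stepping back to the `HandleDot` change above: dividing by the reduction width before multiplying by the rhs element count avoids overflowing int64 on large shapes, and the algebra reduces to the familiar m*n*k FMA count. A quick standalone check of the formula (illustrative arithmetic, not patch code):

```c++
#include <cstdint>
#include <cstdio>

int main() {
  // lhs is [m, k], rhs is [k, n]; reduction_width is k.
  const int64_t m = 4, k = 3, n = 5;
  // (ElementsIn(lhs) / k) * ElementsIn(rhs) == (m*k/k) * (k*n) == m*n*k.
  const int64_t fma_count = ((m * k) / k) * (k * n);  // 60
  std::printf("fmas=%lld flops=%lld\n", static_cast<long long>(fma_count),
              static_cast<long long>(2 * fma_count));  // 60 and 120
  return 0;
}
```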
HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor; + HloCostAnalysis visitor(shape_size_); TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); // Compute the cost of all elements for this Map operation. - auto element_count = ShapeUtil::ElementsIn(map->shape()); - transcendental_count_ += element_count * visitor.transcendental_count(); - auto hlo_flop_count = element_count * visitor.flop_count(); - hlo_to_flop_count_[map] = hlo_flop_count; - flop_count_ += hlo_flop_count; + int64 element_count = ShapeUtil::ElementsIn(map->shape()); + current_transcendental_count_ = + element_count * visitor.transcendental_count(); + current_flop_count_ = element_count * visitor.flop_count(); return Status::OK(); } @@ -180,16 +228,15 @@ Status HloCostAnalysis::HandleReduce( tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { // Compute the cost of the user function. HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor; + HloCostAnalysis visitor(shape_size_); TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); // Compute the cost of all elements for this Reduce operation. - auto reduction_count = ShapeUtil::ElementsIn(arg->shape()) - - ShapeUtil::ElementsIn(reduce->shape()); - auto hlo_flop_count = reduction_count * visitor.flop_count(); - hlo_to_flop_count_[reduce] = hlo_flop_count; - flop_count_ += hlo_flop_count; - transcendental_count_ += reduction_count * visitor.transcendental_count(); + int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - + ShapeUtil::ElementsIn(reduce->shape()); + current_flop_count_ = reduction_count * visitor.flop_count(); + current_transcendental_count_ = + reduction_count * visitor.transcendental_count(); return Status::OK(); } @@ -199,7 +246,7 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, HloComputation* function) { // Compute the cost of the user function. HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor; + HloCostAnalysis visitor(shape_size_); TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); // Compute the cost of all elements for this ReduceWindow operation. For each @@ -209,10 +256,8 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, for (const auto& dimension : window.dimensions()) { window_size *= dimension.size(); } - auto hlo_flop_count = output_size * (window_size - 1) * visitor.flop_count(); - hlo_to_flop_count_[reduce_window] = hlo_flop_count; - flop_count_ += hlo_flop_count; - transcendental_count_ += + current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count(); + current_transcendental_count_ = output_size * (window_size - 1) * visitor.transcendental_count(); return Status::OK(); } @@ -220,10 +265,10 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { // Compute the cost of the select and scatter function. HloInstruction* select = instruction->select()->root_instruction(); - HloCostAnalysis select_visitor; + HloCostAnalysis select_visitor(shape_size_); TF_RETURN_IF_ERROR(select->Accept(&select_visitor)); HloInstruction* scatter = instruction->scatter()->root_instruction(); - HloCostAnalysis scatter_visitor; + HloCostAnalysis scatter_visitor(shape_size_); TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor)); // Compute the cost of all elements for this operation. 
For each scatter @@ -235,12 +280,10 @@ Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { for (const auto& dimension : instruction->window().dimensions()) { window_size *= dimension.size(); } - auto hlo_flop_count = + current_flop_count_ = source_element_count * ((window_size - 1) * select_visitor.flop_count() + scatter_visitor.flop_count()); - hlo_to_flop_count_[instruction] = hlo_flop_count; - flop_count_ += hlo_flop_count; - transcendental_count_ += + current_transcendental_count_ = source_element_count * ((window_size - 1) * select_visitor.transcendental_count() + scatter_visitor.transcendental_count()); @@ -248,6 +291,8 @@ Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { } Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { + // A bitcast does no computation and touches no memory. + current_bytes_accessed_ = 0; return Status::OK(); } @@ -286,10 +331,7 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const int64 fmas_per_output_element = ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - const double hlo_flop_count = static_cast(output_elements) * - fmas_per_output_element * kFmaFlops; - flop_count_ += hlo_flop_count; - hlo_to_flop_count_[convolution] = hlo_flop_count; + current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops; return Status::OK(); } @@ -299,9 +341,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - const double hlo_flop_count = ShapeUtil::ElementsIn(crs->shape()); - flop_count_ += hlo_flop_count; - hlo_to_flop_count_[crs] = hlo_flop_count; + current_flop_count_ = ShapeUtil::ElementsIn(crs->shape()); return Status::OK(); } @@ -310,27 +350,32 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. - transcendental_count_ += ShapeUtil::ElementsIn(random->shape()); + current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape()); return Status::OK(); } Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { // Compute the cost of the fused expression. HloInstruction* fused_expression_root = fusion->fused_expression_root(); - HloCostAnalysis visitor; + // Don't compute sizes inside of fused ops. We don't use the size here and the + // operations inside might not have a layout. + HloCostAnalysis visitor([](const Shape&) { return 0; }); TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); // Attribute the cost of the fused expression to the fusion node. 
- transcendental_count_ += visitor.transcendental_count(); - hlo_to_flop_count_[fusion] += visitor.flop_count(); - flop_count_ += visitor.flop_count(); + current_transcendental_count_ = visitor.transcendental_count(); + current_flop_count_ = visitor.flop_count(); return Status::OK(); } -Status HloCostAnalysis::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { - return Unimplemented("call"); +Status HloCostAnalysis::HandleCall(HloInstruction* call) { + HloCostAnalysis computation_visitor(shape_size_); + TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor)); + + current_flop_count_ = computation_visitor.flop_count(); + current_transcendental_count_ = computation_visitor.transcendental_count(); + current_bytes_accessed_ = computation_visitor.bytes_accessed(); + return Status::OK(); } Status HloCostAnalysis::HandleCustomCall( @@ -343,28 +388,49 @@ Status HloCostAnalysis::HandleCustomCall( Status HloCostAnalysis::HandleSort(HloInstruction* sort, HloInstruction* operand_instruction) { // The cost of sort is implementation dependent, so cannot determine at HLO - // level. Maybe just assume the comparison based N*log(N) sorting? - // TODO(b/26346211): Implement the cost model for sort. - return Unimplemented("HandleSort"); + // level. Assume comparison based N*log(N) sorting. + int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape()); + current_flop_count_ = elements * tensorflow::Log2Ceiling(elements); + return Status::OK(); } -Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while, - HloInstruction* init, - HloComputation* condition, - HloComputation* body) { +Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { // Since the number of iterations of the while node is not statically - // determined, we cannot analyze the computation cost of a while node. - // TODO(b/26346211): Add cost analysis for while node. - return Unimplemented("HandleWhile"); + // determined, we cannot precisely compute the cost of a while node. For now + // compute the cost of a single iteration. + // TODO(b/26346211): Improve the cost analysis for while node. + HloCostAnalysis body_visitor(shape_size_); + TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor)); + HloCostAnalysis condition_visitor(shape_size_); + TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor)); + + current_flop_count_ = + body_visitor.flop_count() + condition_visitor.flop_count(); + current_transcendental_count_ = body_visitor.transcendental_count() + + condition_visitor.transcendental_count(); + current_bytes_accessed_ = + body_visitor.bytes_accessed() + condition_visitor.bytes_accessed(); + + return Status::OK(); } Status HloCostAnalysis::FinishVisit(HloInstruction* root) { return Status::OK(); } -double HloCostAnalysis::hlo_to_flop_count(const HloInstruction& hlo) const { +int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { auto it = hlo_to_flop_count_.find(&hlo); - return it == hlo_to_flop_count_.end() ? 0.0 : it->second; + return it == hlo_to_flop_count_.end() ? 0 : it->second; +} + +int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { + auto it = hlo_to_transcendental_count_.find(&hlo); + return it == hlo_to_transcendental_count_.end() ? 0 : it->second; +} + +int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { + auto it = hlo_to_bytes_accessed_.find(&hlo); + return it == hlo_to_bytes_accessed_.end() ? 
0 : it->second;
+}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 2377b5b9be1..b2c40f75ca4 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -35,8 +36,11 @@ namespace xla {
 // operations separately from transcendental operations.
 class HloCostAnalysis : public DfsHloVisitor {
  public:
-  HloCostAnalysis() = default;
-  ~HloCostAnalysis() override = default;
+  // shape_size is a function which returns the size in bytes of the top-level
+  // buffer of a shape.
+  using ShapeSizeFunction = std::function<int64(const Shape&)>;
+  explicit HloCostAnalysis(const ShapeSizeFunction& shape_size)
+      : shape_size_(shape_size) {}
 
   Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode,
                                 HloInstruction* operand) override;
@@ -80,16 +84,14 @@ class HloCostAnalysis : public DfsHloVisitor {
                       tensorflow::gtl::ArraySlice<int64> dimensions,
                       HloComputation* function_handle) override;
   Status HandleFusion(HloInstruction* fusion) override;
-  Status HandleCall(HloInstruction* call,
-                    tensorflow::gtl::ArraySlice<HloInstruction*> operands,
-                    HloComputation* computation) override;
+  Status HandleCall(HloInstruction* call) override;
   Status HandleCustomCall(HloInstruction* custom_call,
                           tensorflow::gtl::ArraySlice<HloInstruction*> operands,
                           tensorflow::StringPiece custom_call_target) override;
   Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
-  Status HandleDynamicSlice(
-      HloInstruction* slice,
-      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+  Status HandleDynamicSlice(HloInstruction* dynamic_slice,
+                            HloInstruction* operand,
+                            HloInstruction* start_indices) override;
   Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
                                   HloInstruction* operand,
                                   HloInstruction* update,
@@ -111,34 +113,54 @@ class HloCostAnalysis : public DfsHloVisitor {
   Status HandlePad(HloInstruction* pad) override;
   Status HandleReshape(HloInstruction* reshape) override;
   Status HandleTranspose(HloInstruction* transpose) override;
-  Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
-                     HloComputation* condition, HloComputation* body) override;
+  Status HandleWhile(HloInstruction* xla_while) override;
   Status FinishVisit(HloInstruction* root) override;
 
-  // Returns the amount of computations in the graph.
-  double flop_count() { return flop_count_; }
-  double transcendental_count() { return transcendental_count_; }
+  Status Preprocess(HloInstruction* hlo) override;
+  Status Postprocess(HloInstruction* hlo) override;
 
-  // Resolves the provided HLO instruction to a flop count, or 0 if the HLO was
-  // not found to have a flop count in the analysis.
-  double hlo_to_flop_count(const HloInstruction& hlo) const;
+  // Returns the amount of computation in the graph.
+  int64 flop_count() const { return flop_count_; }
+  int64 transcendental_count() const { return transcendental_count_; }
+
+  // Returns the respective cost computed for a particular HLO instruction, or 0
+  // if the HLO was not found to have a cost in the analysis.
+  int64 flop_count(const HloInstruction& hlo) const;
+  int64 transcendental_count(const HloInstruction& hlo) const;
+
+  // Returns the number of bytes read/written.
+  int64 bytes_accessed(const HloInstruction& hlo) const;
+  int64 bytes_accessed() const { return bytes_accessed_; }
 
  private:
   // An FMA counts as two floating point operations in these analyses.
   static constexpr int64 kFmaFlops = 2;
 
   // Utility function to handle all element-wise operations.
   Status HandleElementwiseOp(HloInstruction* hlo_instruction);
 
-  // Mapping from HLO instructions to the flop count we computed for them in the
+  // Function which computes the size of the top-level buffer of a given shape
+  // (not including nested elements, if any). If null then bytes_accessed
+  // methods return an error.
+  const ShapeSizeFunction shape_size_;
+
+  // The total number of floating point operations, transcendental operations,
+  // and bytes accessed (read or written) in the computation.
+  int64 flop_count_ = 0;
+  int64 transcendental_count_ = 0;
+  int64 bytes_accessed_ = 0;
+
+  // Cost counts of the current instruction. These should be set by each
+  // handler if different from the default values computed in Preprocess.
+  int64 current_flop_count_;
+  int64 current_transcendental_count_;
+  int64 current_bytes_accessed_;
+
+  // Mapping from HLO instructions to the cost we computed for them in the
   // course of the graph analysis.
-  std::map<const HloInstruction*, double> hlo_to_flop_count_;
-
-  // The number of floating point operations in the graph.
-  double flop_count_ = 0.0;
-
-  // The number of transcendental operations in the graph.
-  double transcendental_count_ = 0.0;
+  std::map<const HloInstruction*, int64> hlo_to_flop_count_;
+  std::map<const HloInstruction*, int64> hlo_to_transcendental_count_;
+  std::map<const HloInstruction*, int64> hlo_to_bytes_accessed_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
 };
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index f55d939b42e..b74c7eb4e07 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -39,6 +39,12 @@ limitations under the License.
 namespace xla {
 namespace {
 
+constexpr int64 kPointerSize = 8;
+
+int64 ShapeSize(const Shape& shape) {
+  return ShapeUtil::ByteSizeOf(shape, kPointerSize);
+}
+
 // This test suite tests the HLO cost analysis by first building a computation
 // using the client computation builder and running the HloCostAnalysis that
 // returns the number of floating point and transcendental operations in the
@@ -48,7 +54,7 @@ class HloCostAnalysisTest : public ::testing::Test {
   HloCostAnalysisTest()
       : client_(ClientLibrary::LocalClientOrDie()),
         // Accessing service instance is required for the unit tests to enable
-        // whitebox acccesses to the user computation built from the client,
+        // whitebox accesses to the user computation built from the client,
        // as shown in the BuildHloGraph functions below.
        service_(static_cast<Service*>(ClientLibrary::GetXlaService(
            static_cast<LocalClient*>(client_)->platform()))),
@@ -121,7 +127,8 @@ class HloCostAnalysisTest : public ::testing::Test {
     VersionedComputationHandle versioned_handle =
         user_computation->GetVersionedHandle();
     return std::move(
-        computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie());
+        computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig())
+            .ValueOrDie());
   }
 
   Client* client_;
@@ -144,12 +151,18 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
 
   // Run HLO cost analysis.
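
Each test below drives the analysis the same way: build the graph, construct a visitor with the ShapeSize helper defined above, and walk the entry root. A minimal sketch of that pattern, with the worked numbers behind the MatrixMultiply expectations:

```cpp
// Sketch of the test pattern, assuming hlo_module was built by BuildHloGraph.
// For the [10,5] x [5,30] dot: flop_count = 10 * 30 * 5 * kFmaFlops = 3000
// (1500 FMAs), bytes_accessed = sizeof(float) * (10*5 + 5*30 + 10*30) = 2000.
HloCostAnalysis analysis(ShapeSize);
TF_CHECK_OK(
    hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
VLOG(1) << "flops=" << analysis.flop_count()
        << " transcendentals=" << analysis.transcendental_count()
        << " bytes=" << analysis.bytes_accessed();
```
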
auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); // Check the number of computations returned from the analysis (1500 FMAs). EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30)); } TEST_F(HloCostAnalysisTest, Map) { @@ -159,13 +172,14 @@ TEST_F(HloCostAnalysisTest, Map) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); // add contributes to 10 flops and exp contributes to 10 transcendental ops. EXPECT_EQ(analysis.flop_count(), 10); EXPECT_EQ(analysis.transcendental_count(), 10); + EXPECT_EQ(analysis.bytes_accessed(), 80); } TEST_F(HloCostAnalysisTest, Convolution) { @@ -182,13 +196,17 @@ TEST_F(HloCostAnalysisTest, Convolution) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); // Output shape is [1x1x8x18] and each output element requires (3x3) // FMAs and one FMA is 2 flops. EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } TEST_F(HloCostAnalysisTest, Reduce) { @@ -200,7 +218,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -218,7 +236,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -238,7 +256,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -251,7 +269,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { ComputationBuilder b(client_, "broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); EXPECT_EQ(analysis.flop_count(), 0); @@ -271,7 +289,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { // Run HLO cost analysis. 
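
The Convolution expectations above follow directly from HandleConvolution's formula; spelled out with that test's shapes (a worked restatement, not new code):

```cpp
// Worked numbers for the Convolution test above:
const int64 output_features = 1;                                    // [3,3] kernel
const int64 fmas_per_output_element = 3 * 3 / output_features;      // = 9
const int64 output_elements = 1 * 1 * 8 * 18;                       // = 144
const int64 flops = output_elements * fmas_per_output_element * 2;  // kFmaFlops
// flops == 2592 == 8 * 18 * 2 * 3 * 3, and bytes_accessed sums the three
// float buffers: 4 * (10*20 + 3*3 + 8*18) == 1412.
```
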
auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis; + HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -282,7 +300,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { } TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { - HloCostAnalysis conv_analysis; + HloCostAnalysis conv_analysis(ShapeSize); { ComputationBuilder builder(client_, "conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -295,7 +313,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { &conv_analysis)); } - HloCostAnalysis matmul_analysis; + HloCostAnalysis matmul_analysis(ShapeSize); { ComputationBuilder builder(client_, "matmul"); auto lhs = @@ -311,28 +329,6 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count()); } -// Note that we still expect that any given operation won't overflow 2^64 FLOPs, -// just that the sum total may. -TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) { - HloCostAnalysis matmul_analysis; - { - ComputationBuilder builder(client_, "matmul"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1LL << 62}), - "input"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1LL << 62, 1}), - "weights"); - auto a = builder.Dot(lhs, rhs); - auto b = builder.Dot(a, lhs); - builder.Dot(b, rhs); - auto hlo_module = BuildHloGraph(&builder); - ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( - &matmul_analysis)); - } - - LOG(INFO) << matmul_analysis.flop_count(); - EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits::max()); -} - using FusionCostAnalysis = ::testing::Test; TEST_F(FusionCostAnalysis, LoopFusion) { @@ -373,12 +369,57 @@ TEST_F(FusionCostAnalysis, LoopFusion) { fusion->FuseInstruction(clamp.get()); fusion->FuseInstruction(add.get()); - HloCostAnalysis fusion_analysis; + HloCostAnalysis fusion_analysis(ShapeSize); ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); EXPECT_EQ(fusion_analysis.flop_count(), 16); EXPECT_EQ(fusion_analysis.transcendental_count(), 4); } +TEST_F(FusionCostAnalysis, NoLayout) { + Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); + // Instructions within a fused op may have no layout. 
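
The expectations asserted at the end of this test fall out as follows (a worked note; the reasoning mirrors HandleFusion above):

```cpp
// Worked numbers for the NoLayout test: the fused add runs once per element
// of the [2,3,4,5] output and the fusion contains no transcendental ops, so
//   flop_count           = 2 * 3 * 4 * 5 = 120
//   transcendental_count = 0
// The missing layout on the broadcast operand is harmless because the fusion
// visitor's shape-size function ([](const Shape&) { return 0; }) never
// inspects the shape.
```
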
+ Shape shape_without_layout = shape_with_layout; + shape_without_layout.clear_layout(); + + auto c1 = HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); + auto c2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + + auto broadcast = + HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); + auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd, + c1.get(), broadcast.get()); + + auto fusion = HloInstruction::CreateFusion( + shape_with_layout, HloInstruction::FusionKind::kLoop, add.get()); + fusion->FuseInstruction(broadcast.get()); + + HloCostAnalysis fusion_analysis(ShapeSize); + ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); + + EXPECT_EQ(fusion_analysis.flop_count(), 120); + EXPECT_EQ(fusion_analysis.transcendental_count(), 0); +} + +TEST_F(HloCostAnalysisTest, TupleCost) { + HloCostAnalysis analysis(ShapeSize); + { + ComputationBuilder builder(client_, "matmul"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); + auto tuple = builder.Tuple({x, y}); + auto hlo_module = BuildHloGraph(&builder); + + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + } + + EXPECT_EQ(analysis.flop_count(), 0); + EXPECT_EQ(analysis.transcendental_count(), 0); + EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index ec8161f55fd..cc39c3ac203 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -36,6 +37,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,7 +57,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -84,17 +87,19 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_NE(add->operand(0), add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_EQ(add->operand(0), add->operand(1)); + auto first_operand = add->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); + EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -114,19 +119,17 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -153,13 +156,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(7, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(6, computation->instruction_count()); } @@ -181,20 +184,22 @@ TEST_F(HloCseTest, NonscalarConstants) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( {common_constant1, common_constant2, uncommon_constant})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); + EXPECT_THAT(tuple, + op::Tuple(common_constant1, common_constant2, 
uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(uncommon_constant, tuple->operand(2)); - EXPECT_TRUE(tuple->operand(0) == common_constant1 || - tuple->operand(0) == common_constant2); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, + ::testing::AnyOf(common_constant1, common_constant2)); + EXPECT_THAT(tuple, + op::Tuple(first_operand, first_operand, uncommon_constant)); } TEST_F(HloCseTest, IdenticalInstructions) { @@ -211,20 +216,19 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); - EXPECT_NE(tuple->operand(1), tuple->operand(2)); - EXPECT_NE(tuple->operand(0), tuple->operand(2)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(tuple->operand(1), tuple->operand(2)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2, exp3)); + EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { @@ -245,17 +249,17 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(&module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { @@ -276,17 +280,19 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2)); + 
EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalExpressions) { @@ -324,18 +330,19 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode()); + auto operand = tuple->operand(0); + EXPECT_THAT(tuple, op::Tuple(operand, operand)); + EXPECT_THAT(operand, op::Add(op::Negate(), op::Exp())); } TEST_F(HloCseTest, DoNotCombineRng) { @@ -351,12 +358,16 @@ TEST_F(HloCseTest, DoNotCombineRng) { auto rng2 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); + builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); + uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); @@ -364,11 +375,8 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); } // TODO(b/28245743): Handle impure functions correctly in CSE. @@ -376,7 +384,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); // rng_function is an impure function because it does RNG. 
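
The rewritten CSE assertions above and below rely on the gmock-style opcode matchers pulled in via hlo_matchers.h. A short sketch of the two matcher forms, assuming an op::Rng() opcode matcher is generated like the others:

```cpp
namespace op = xla::testing::opcode_matchers;
// Pointer-identity form: the root must be an Add of these exact instructions.
EXPECT_THAT(root, op::Add(rng1, rng2));
// Opcode-only form: operands only need matching opcodes, not identity.
EXPECT_THAT(root, op::Add(op::Rng(), op::Rng()));
```
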
HloComputation* rng_function = nullptr; @@ -412,17 +420,22 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { } EXPECT_EQ(4, computation->instruction_count()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(op::Map(), op::Map())); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + auto operand = root->operand(0)->operand(0); + EXPECT_THAT(operand, op::Map()); + EXPECT_THAT(root, op::Add(operand, operand)); } } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc new file mode 100644 index 00000000000..d1b87256445 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -0,0 +1,834 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +string HloLocation::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + return StrCat(instruction->FullyQualifiedName(), index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloLocation& location) { + out << location.ToString(); + return out; +} + +string HloUse::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) + ? 
(" " + operand_index.ToString()) + : ""; + return StrCat(instruction->FullyQualifiedName(), ", operand ", operand_number, + index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloUse& use) { + out << use.ToString(); + return out; +} + +HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, + const ShapeIndex& index, bool is_phi) + : id_(id), is_phi_(is_phi) { + // The defining location is always the first element in the locations_ vector. + AddLocation(instruction, index); +} + +bool HloValue::operator==(const HloValue& other) const { + bool equal = instruction() == other.instruction() && index() == other.index(); + // If the values are equal they most both be phi (or non phi). + CHECK(!(equal && is_phi() != other.is_phi())); + return equal; +} + +bool HloValue::operator!=(const HloValue& other) const { + return !(*this == other); +} + +string HloValue::ToShortString() const { + string index_str = + ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : ""; + return StrCat(is_phi_ ? "PHI " : "", instruction()->FullyQualifiedName(), + index_str); +} + +string HloValue::ToString(int indent) const { + string indentation(indent, ' '); + string out = StrCat(indentation, ToShortString(), ", locations:\n"); + for (const HloLocation& location : locations()) { + StrAppend(&out, indentation, " ", location.ToString(), "\n"); + } + StrAppend(&out, indentation, " uses:\n"); + for (const HloUse& use : uses()) { + StrAppend(&out, indentation, " ", use.ToString(), "\n"); + } + return out; +} + +void HloValue::AddLocation(HloInstruction* instruction, + const ShapeIndex& index) { + // The given location should not already exist in locations_. + for (const HloLocation& location : locations_) { + DCHECK(!(location.instruction == instruction && location.index == index)); + } + + locations_.push_back(HloLocation{instruction, index}); + + // Update uses. + for (HloInstruction* user : instruction->users()) { + for (int64 operand_number : user->OperandIndices(instruction)) { + if (!DoesNotUseOperandBuffer(instruction, index, user)) { + for (const HloUse& use : uses_) { + // Verify that this use does not already exist. + DCHECK(!(use.instruction == user && + use.operand_number == operand_number && + use.operand_index == index)); + } + + uses_.push_back(HloUse{user, operand_number, index}); + } + } + } + + // Update liveout status of this HloValue. + const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } +} + +void HloValue::RemoveLocation(HloInstruction* instruction, + const ShapeIndex& index) { + // The defining location cannot be removed. + CHECK(!(instruction == this->instruction() && index == this->index())); + + int64 size_before = locations_.size(); + locations_.erase( + std::remove_if(locations_.begin(), locations_.end(), + [instruction, &index](const HloLocation& location) { + return location.instruction == instruction && + location.index == index; + }), + locations_.end()); + // Only a single location should have been removed. + CHECK_EQ(locations_.size(), size_before - 1); + + // Update uses which referred to this location. 
+ uses_.erase(std::remove_if(uses_.begin(), uses_.end(), + [instruction, &index](const HloUse& use) { + return use.instruction->operand( + use.operand_number) == instruction && + use.operand_index == index; + }), + uses_.end()); + + const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + // Value has been removed from a location in the entry root instruction. + // Check if the value is still live out of the module by walking all + // remaining locations. + live_out_of_module_ = false; + for (const HloLocation& location : locations()) { + if (location.instruction == + module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + break; + } + } + } +} + +std::ostream& operator<<(std::ostream& out, const HloValue& value) { + out << value.ToShortString(); + return out; +} + +void HloValueSet::SortAndUniquifyValues() { + std::sort(value_ids_.begin(), value_ids_.end()); + value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()), + value_ids_.end()); +} + +string HloValueSet::ToString() const { + return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", ")); +} + +/*static */ +HloValueSet HloValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + HloValueSet union_set; + for (const HloValueSet* input : inputs) { + for (HloValue::Id value_id : input->value_ids()) { + union_set.value_ids_.push_back(value_id); + } + } + union_set.SortAndUniquifyValues(); + return union_set; +} + +std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { + out << value_set.ToString(); + return out; +} + +InstructionValueSet InstructionValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + CHECK_GT(inputs.size(), 0); + for (int i = 1; i < inputs.size(); ++i) { + CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); + } + InstructionValueSet union_set(inputs[0]->shape()); + union_set.ForEachMutableElement( + [&inputs](const ShapeIndex& index, HloValueSet* value_set) { + std::vector input_sets; + for (const InstructionValueSet* input : inputs) { + input_sets.push_back(&input->element(index)); + } + *value_set = HloValueSet::Union(input_sets); + }); + return union_set; +} + +std::ostream& operator<<(std::ostream& out, + const InstructionValueSet& instruction_value_set) { + out << instruction_value_set.ToString(); + return out; +} + +string InstructionValueSet::ToString() const { + string out = + StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloValueSet& value_set) { + StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); + }); + return out; +} + +HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, + bool bitcast_defines_value) + : module_(module), + ssa_form_(ssa_form), + bitcast_defines_value_(bitcast_defines_value), + call_graph_(CallGraph::Build(module)) {} + +bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, + const ShapeIndex& index) const { + const HloValueSet& value_set = GetValueSet(instruction, index); + if (value_set.value_ids().size() != 1) { + return false; + } + return GetValue(value_set.GetUniqueValueId()).instruction() == instruction; +} + +const HloValue& HloDataflowAnalysis::GetValueDefinedAt( + const HloInstruction* instruction, const ShapeIndex& index) const { + CHECK(ValueIsDefinedAt(instruction, index)); + return GetUniqueValueAt(instruction, index); +} + +HloValue& 
HloDataflowAnalysis::GetValueDefinedAt( + const HloInstruction* instruction, const ShapeIndex& index) { + CHECK(ValueIsDefinedAt(instruction, index)); + return GetUniqueValueAt(instruction, index); +} + +HloValue::Id HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, + const ShapeIndex& index, + bool is_phi) { + int64 value_id = next_value_id_++; + auto it_added = values_.emplace( + std::piecewise_construct, std::forward_as_tuple(value_id), + std::forward_as_tuple(value_id, instruction, index, is_phi)); + CHECK(it_added.second); + + // Clear the vector of values as it is now stale. It will be lazily + // reconstructed if needed when HloDataflowAnalysis::values() is called. + values_vector_.clear(); + + return value_id; +} + +void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { + values_.erase(value_id); + + // Clear the vector of values as it is now stale. It will be lazily + // reconstructed if needed when HloDataflowAnalysis::values() is called. + values_vector_.clear(); +} + +string HloDataflowAnalysis::ToString() const { + string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); + StrAppend(&out, " Instruction value sets:\n"); + for (const std::unique_ptr& computation : + module_->computations()) { + for (const std::unique_ptr& instruction : + computation->instructions()) { + StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); + if (ShapeUtil::IsTuple(instruction->shape())) { + GetInstructionValueSet(instruction.get()) + .ForEachElement([this, &instruction, &out]( + const ShapeIndex& index, + const HloValueSet& value_set) { + StrAppend(&out, " tuple index ", index.ToString(), ":\n"); + for (HloValue::Id value_id : value_set.value_ids()) { + StrAppend( + &out, " ", GetValue(value_id).ToShortString(), + ValueIsDefinedAt(instruction.get(), index) ? " (def)" : "", + "\n"); + } + }); + } else { + const HloValueSet& top_level_value_set = + GetValueSet(instruction.get(), /*index=*/{}); + for (HloValue::Id value_id : top_level_value_set.value_ids()) { + StrAppend(&out, " ", GetValue(value_id).ToShortString(), + ValueIsDefinedAt(instruction.get()) ? " (def)" : "", "\n"); + } + } + } + } + StrAppend(&out, " HloValues:\n"); + for (const auto& pair : values_) { + StrAppend(&out, pair.second.ToString(/*indent=*/4)); + } + return out; +} + +const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const { + return values_.at(value_id); +} + +HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) { + return values_.at(value_id); +} + +const HloValueSet& HloDataflowAnalysis::GetValueSet( + const HloInstruction* instruction, const ShapeIndex& index) const { + return GetInstructionValueSet(instruction).element(index); +} + +HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction, + const ShapeIndex& index) { + return *GetInstructionValueSet(instruction).mutable_element(index); +} + +const std::vector& HloDataflowAnalysis::values() const { + if (values_vector_.empty()) { + // Lazily construct vector of values. 
+    values_vector_.reserve(values_.size());
+    for (auto& pair : values_) {
+      values_vector_.push_back(&pair.second);
+    }
+    std::sort(
+        values_vector_.begin(), values_vector_.end(),
+        [](const HloValue* a, const HloValue* b) { return a->id() < b->id(); });
+  } else {
+    CHECK_EQ(values_vector_.size(), values_.size());
+    for (const HloValue* value : values_vector_) {
+      DCHECK(ContainsKey(values_, value->id()));
+      DCHECK(&GetValue(value->id()) == value);
+    }
+  }
+  return values_vector_;
+}
+
+InstructionValueSet HloDataflowAnalysis::Phi(
+    HloInstruction* instruction,
+    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
+    bool skip_top_level) {
+  CHECK(ssa_form_);
+
+  for (const InstructionValueSet* input : inputs) {
+    CHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
+  }
+  InstructionValueSet new_value_set(instruction->shape());
+  new_value_set.ForEachMutableElement(
+      [this, instruction, &inputs, skip_top_level](const ShapeIndex& index,
+                                                   HloValueSet* value_set) {
+        // If we're skipping the top level, just copy over the existing
+        // HloValueSet.
+        if (skip_top_level && index.empty()) {
+          *value_set = GetInstructionValueSet(instruction).element(index);
+          return;
+        }
+
+        // Identify the existing phi value at this index if it exists.
+        const HloValue* existing_phi_value = nullptr;
+        if (ValueIsDefinedAt(instruction, index) &&
+            GetUniqueValueAt(instruction, index).is_phi()) {
+          existing_phi_value = &GetUniqueValueAt(instruction, index);
+        }
+
+        // Construct a vector of unique value IDs of the inputs.
+        std::vector<HloValue::Id> input_value_ids;
+        for (const InstructionValueSet* input : inputs) {
+          for (HloValue::Id value_id : input->element(index).value_ids()) {
+            input_value_ids.push_back(value_id);
+          }
+        }
+        std::sort(input_value_ids.begin(), input_value_ids.end());
+        input_value_ids.erase(
+            std::unique(input_value_ids.begin(), input_value_ids.end()),
+            input_value_ids.end());
+
+        // Remove the existing phi value (if it exists). The phi can be its own
+        // input, for example, in while body parameters where the body passes
+        // through the parameter value.
+        if (existing_phi_value != nullptr) {
+          auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
+                              existing_phi_value->id());
+          if (it != input_value_ids.end()) {
+            input_value_ids.erase(it);
+          }
+        }
+
+        if (input_value_ids.size() <= 1) {
+          if (input_value_ids.size() == 1) {
+            *value_set = HloValueSet({input_value_ids[0]});
+          }
+          if (existing_phi_value) {
+            // The merge point does not have multiple distinct inputs (which are
+            // not the phi value itself). Therefore there is no need to insert a
+            // phi value because there is a single reaching definition (or no
+            // reaching definition).
+            DeleteHloValue(existing_phi_value->id());
+          }
+        } else if (input_value_ids.size() > 1) {
+          // Multiple distinct values reach this point. A phi value is
+          // necessary.
+          if (existing_phi_value) {
+            // A phi value already exists so reuse it in the new
+            // InstructionValueSet.
+            *value_set = HloValueSet({existing_phi_value->id()});
+          } else {
+            // Create a new phi value.
+            *value_set =
+                HloValueSet({NewHloValue(instruction, index, /*is_phi=*/true)});
+          }
+        }
+      });
+  return new_value_set;
+}
+
+void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
+    HloInstruction* instruction, const InstructionValueSet& new_value_set,
+    const InstructionValueSet* prev_value_set) {
+  if (prev_value_set != nullptr) {
+    // Remove locations from the old value set.
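
The deletion and reuse rules in Phi above are easiest to see at a while-body parameter; a worked trace (hypothetical value names):

```cpp
// Phi at a while-body parameter, index {} only:
//   inputs = {v_init} from the while operand, {v_back} from the body root.
// Case 1: the body transforms the value, so v_back is distinct from the phi.
//   Two distinct inputs remain -> a phi value is created (or reused).
// Case 2: the body passes the parameter through, so v_back *is* the phi.
//   It is erased from input_value_ids, leaving {v_init}; the phi is deleted
//   and v_init becomes the single reaching definition.
```
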
+ prev_value_set->ForEachElement( + [this, instruction](const ShapeIndex& index, + const HloValueSet& value_set) { + for (HloValue::Id value_id : value_set.value_ids()) { + // HloValues in the previous value set may have been deleted. + if (!ContainsKey(values_, value_id)) { + continue; + } + // Don't remove the defining location of the value. + HloValue& value = GetValue(value_id); + if (instruction == value.instruction()) { + CHECK_EQ(index, value.index()); + } else { + value.RemoveLocation(instruction, index); + } + } + }); + } + // Add locations in the new value set. + new_value_set.ForEachElement( + [this, instruction](const ShapeIndex& index, + const HloValueSet& value_set) { + for (HloValue::Id value_id : value_set.value_ids()) { + HloValue& value = GetValue(value_id); + if (instruction == value.instruction()) { + CHECK_EQ(index, value.index()); + } else { + value.AddLocation(instruction, index); + } + } + }); +} + +InstructionValueSet HloDataflowAnalysis::RecomputeBitcastValueSet( + HloInstruction* bitcast) { + CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); + if (bitcast_defines_value_) { + return GetInstructionValueSet(bitcast); + } else { + return GetInstructionValueSet(bitcast->operand(0)); + } +} + +InstructionValueSet HloDataflowAnalysis::RecomputeCopyValueSet( + HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + InstructionValueSet new_value_set = GetInstructionValueSet(copy); + if (ShapeUtil::IsTuple(copy->shape())) { + for (int i = 0; i < ShapeUtil::TupleElementCount(copy->shape()); ++i) { + new_value_set.CopySubtreeFrom(GetInstructionValueSet(copy->operand(0)), + /*source_base_index=*/{i}, + /*target_base_index=*/{i}); + } + } + return new_value_set; +} + +InstructionValueSet HloDataflowAnalysis::RecomputeGetTupleElementValueSet( + HloInstruction* gte) { + CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); + InstructionValueSet new_value_set(gte->shape()); + new_value_set.CopySubtreeFrom(GetInstructionValueSet(gte->operand(0)), + /*source_base_index=*/{gte->tuple_index()}, + /*target_base_index=*/{}); + return new_value_set; +} + +InstructionValueSet HloDataflowAnalysis::RecomputeSelectValueSet( + HloInstruction* select) { + CHECK_EQ(select->opcode(), HloOpcode::kSelect); + std::vector inputs = { + &GetInstructionValueSet(select->operand(1)), + &GetInstructionValueSet(select->operand(2))}; + // A phi value is not defined at a kSelect instruction because kSelect does + // not create a new value. Rather it forwards a value from its operands. This + // contrasts with kWhile instruction (which does define a phi value) which has + // in-place update semantics. 
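
Concretely, for a tuple-shaped %sel = select(%pred, %a, %b), the union below merges element values from both operands while the top level keeps the value the select itself defines (a worked illustration):

```cpp
// RecomputeSelectValueSet on tuple-shaped operands, illustrated:
//   %a defines {v_a} at index {0};  %b defines {v_b} at index {0}.
// Resulting value set for %sel:
//   index {}:  {v_sel}      // the select's own top-level pointer vector
//   index {0}: {v_a, v_b}   // either operand's element may flow through
```
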
+ InstructionValueSet new_value_set = InstructionValueSet::Union(inputs); + *new_value_set.mutable_element(/*index=*/{}) = + GetInstructionValueSet(select).element(/*index=*/{}); + return new_value_set; +} + +InstructionValueSet HloDataflowAnalysis::RecomputeTupleValueSet( + HloInstruction* tuple) { + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); + InstructionValueSet new_value_set(tuple->shape()); + *new_value_set.mutable_element(/*index=*/{}) = + GetInstructionValueSet(tuple).element(/*index=*/{}); + for (int64 i = 0; i < tuple->operands().size(); ++i) { + new_value_set.CopySubtreeFrom(GetInstructionValueSet(tuple->operand(i)), + /*source_base_index=*/{}, + /*target_base_index=*/{i}); + } + return new_value_set; +} + +InstructionValueSet HloDataflowAnalysis::RecomputeWhileValueSet( + HloInstruction* xla_while) { + CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); + std::vector inputs = { + &GetInstructionValueSet(xla_while->while_body()->root_instruction()), + &GetInstructionValueSet(xla_while->operand(0))}; + if (ssa_form_) { + return Phi(xla_while, inputs); + } else { + return InstructionValueSet::Union(inputs); + } +} + +void HloDataflowAnalysis::UpdateInstructionValueSet( + HloInstruction* instruction) { + // Recompute from operands. + InstructionValueSet& value_set = GetInstructionValueSet(instruction); + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + value_set = RecomputeBitcastValueSet(instruction); + break; + case HloOpcode::kCopy: + value_set = RecomputeCopyValueSet(instruction); + break; + case HloOpcode::kGetTupleElement: + value_set = RecomputeGetTupleElementValueSet(instruction); + break; + case HloOpcode::kSelect: + value_set = RecomputeSelectValueSet(instruction); + break; + case HloOpcode::kTuple: + value_set = RecomputeTupleValueSet(instruction); + break; + case HloOpcode::kParameter: + value_set = RecomputeParameterValueSet(instruction); + break; + case HloOpcode::kCall: + // The output of a kCall instruction is exactly the output of the root of + // the subcomputation. + value_set = + GetInstructionValueSet(instruction->to_apply()->root_instruction()); + break; + case HloOpcode::kWhile: + value_set = RecomputeWhileValueSet(instruction); + break; + default: + // Instruction does not forward HloValues (it defines all values in its + // output). No update is necessary. + return; + } +} + +void HloDataflowAnalysis::UpdateInstructionsAndPropagate( + tensorflow::gtl::ArraySlice instructions) { + std::queue worklist; + for (HloInstruction* instruction : instructions) { + worklist.push(instruction); + } + + while (!worklist.empty()) { + HloInstruction* instruction = worklist.front(); + worklist.pop(); + + VLOG(3) << "Worklist top: " << instruction->name(); + VLOG(3) << ToString(); + + // Save old value for recomputing uses and live out. + InstructionValueSet old_value = GetInstructionValueSet(instruction); + UpdateInstructionValueSet(instruction); + + if (GetInstructionValueSet(instruction) == old_value) { + // No change to the instruction's value set. + VLOG(4) << "No change."; + continue; + } + + VLOG(4) << "New value set for " << instruction->name() << ": " + << GetInstructionValueSet(instruction); + VLOG(4) << "Previously: " << old_value; + + // Instruction value was updated. Add users to work list. + for (HloInstruction* user : instruction->users()) { + worklist.push(user); + + // If user calls a computation, then the respective parameter(s) of the + // computation need to be updated. 
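
For instance, when an updated instruction feeds a kCall, dataflow has to cross the computation boundary in both directions; a worked trace of the pushes below (hypothetical instruction names):

```cpp
// Cross-computation propagation for %c = call(%x, %y), to_apply=f(p0, p1):
// 1. %y's value set changes  -> push f's parameter p1 (operand index 1).
// 2. f's root set changes    -> push %c (a caller callsite of f).
// 3. %c's set changes        -> push %c's users, and so on to a fixpoint.
```
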
+ for (HloComputation* called_computation : user->called_computations()) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } + } + } + + // If instruction is a root instruction, then propagate out to any calling + // instruction and across any while backedge. + if (instruction == instruction->parent()->root_instruction()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(instruction->parent()); + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kCall) { + worklist.push(callsite.instruction()); + } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // Add the while itself, and the body and condition parameters. + worklist.push(callsite.instruction()); + worklist.push( + callsite.instruction()->while_body()->parameter_instruction(0)); + worklist.push( + callsite.instruction()->while_condition()->parameter_instruction( + 0)); + } + } + } + + // Update uses. First clear all of the old uses at the particular + // operands. Then add the new uses. There may be overlap between the old + // uses and new uses. + UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction), + &old_value); + } +} + +InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet( + HloInstruction* parameter) { + CHECK_EQ(parameter->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph_->GetNode(parameter->parent()); + + // Subcomputations called in a parallel context (eg, map) do not have dataflow + // from the caller operands. + if (call_graph_node.context() == CallContext::kParallel || + call_graph_node.caller_callsites().empty()) { + return GetInstructionValueSet(parameter); + } + CHECK_EQ(call_graph_node.context(), CallContext::kSequential); + + std::vector inputs; + bool called_from_while = false; + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + inputs.push_back(&GetInstructionValueSet( + callsite.instruction()->operand(parameter->parameter_number()))); + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // In a while instruction, the backedge is also a dataflow input to the + // parameter instruction. This code covers the case where the parameter is + // in the while body or the parameter is in the while condition. + inputs.push_back(&GetInstructionValueSet( + callsite.instruction()->while_body()->root_instruction())); + called_from_while = true; + } + } + + if (ssa_form_ && called_from_while) { + return Phi(parameter, inputs); + } else { + return InstructionValueSet::Union(inputs); + } +} + +const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( + const HloInstruction* instruction) const { + return value_sets_.at(instruction); +} + +InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( + const HloInstruction* instruction) { + return value_sets_.at(instruction); +} + +Status HloDataflowAnalysis::InitializeInstructionValueSets() { + for (const std::unique_ptr& computation : + module_->computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(computation.get()); + for (const std::unique_ptr& instruction : + computation->instructions()) { + // Create an empty shape tree. 
+      value_sets_.emplace(std::piecewise_construct,
+                          std::forward_as_tuple(instruction.get()),
+                          std::forward_as_tuple(instruction->shape()));
+
+      // Lambda to set the value set to define all values in the output of the
+      // instruction.
+      auto define_all_values = [this, &instruction]() {
+        GetInstructionValueSet(instruction.get())
+            .ForEachMutableElement([this, &instruction](
+                                       const ShapeIndex& index,
+                                       HloValueSet* value_set) {
+              *value_set = HloValueSet({NewHloValue(instruction.get(), index)});
+            });
+      };
+
+      // Lambda to set the value set to define only the top-level buffer in the
+      // output of the instruction. Any other values flow from the operands of
+      // the instruction (or from cross-computation dataflow).
+      auto define_top_level_only = [this, &instruction]() {
+        GetValueSet(instruction.get(), /*index=*/{}) =
+            HloValueSet({NewHloValue(instruction.get(), /*index=*/{})});
+      };
+
+      switch (instruction->opcode()) {
+        case HloOpcode::kBitcast:
+          if (bitcast_defines_value_) {
+            define_all_values();
+          }
+          break;
+        case HloOpcode::kCall:
+        case HloOpcode::kWhile:
+        case HloOpcode::kGetTupleElement:
+          // These instructions define no values. The values in their output
+          // flow from their operands or from cross-computation dataflow.
+          break;
+        case HloOpcode::kParameter:
+          if (call_graph_node.caller_callsites().empty() ||
+              call_graph_node.context() == CallContext::kParallel) {
+            // Parameters of computations called in a parallel context (eg, map
+            // and reduce) as well as parameters of dead computations define all
+            // values in their output. Otherwise the values of the parameter
+            // come from the caller (eg, operands to the kCall instruction).
+            define_all_values();
+          } else if (call_graph_node.context() == CallContext::kBoth) {
+            // We do not support a subcomputation that is called from both a
+            // parallel and sequential context. In this case, the parameter
+            // would both define a value and propagate a value from its
+            // caller. This limitation is not really a problem because the call
+            // graph is typically flattened.
+            return Unimplemented(
+                "Computation %s is called in both a parallel (eg, kMap) and "
+                "sequential (eg, kCall) context",
+                computation->name().c_str());
+          }
+          break;
+        case HloOpcode::kCopy:
+        case HloOpcode::kSelect:
+        case HloOpcode::kTuple:
+          // These instructions only define their top-level values. Any other
+          // values flow from their operands.
+          define_top_level_only();
+          break;
+        default:
+          define_all_values();
+          break;
+      }
+      UpdateLocationsOfValuesAt(instruction.get(),
+                                GetInstructionValueSet(instruction.get()));
+    }
+  }
+  return Status::OK();
+}
+
+/* static */
+StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
+    HloModule* module, bool ssa_form, bool bitcast_defines_value) {
+  VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
+  XLA_VLOG_LINES(2, module->ToString());
+
+  auto dataflow_analysis = WrapUnique(
+      new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
+
+  TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
+
+  // Construct list of all instructions to initialize the worklist to propagate
+  // the data flow. For efficiency sort the instructions in post order so
+  // producers appear before consumers.
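
For reference, a caller-side sketch of this factory (the flag values and surrounding code are illustrative; TF_ASSIGN_OR_RETURN assumes an enclosing Status-returning function):

```cpp
// Sketch: run the analysis over a module and dump every value.
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                    HloDataflowAnalysis::Run(module, /*ssa_form=*/false,
                                             /*bitcast_defines_value=*/false));
for (const HloValue* value : dataflow->values()) {
  VLOG(1) << value->ToShortString();
}
```
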
+
+/* static */
+StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
+    HloModule* module, bool ssa_form, bool bitcast_defines_value) {
+  VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
+  XLA_VLOG_LINES(2, module->ToString());
+
+  auto dataflow_analysis = WrapUnique(
+      new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
+
+  TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
+
+  // Construct a list of all instructions to initialize the worklist which
+  // propagates the dataflow. For efficiency, sort the instructions in post
+  // order so producers appear before consumers.
+  std::vector<HloInstruction*> all_instructions;
+  for (const HloComputation* computation : module->MakeComputationPostOrder()) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      all_instructions.push_back(instruction);
+    }
+  }
+  dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
+
+  VLOG(1) << dataflow_analysis->ToString();
+  return std::move(dataflow_analysis);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
new file mode 100644
index 00000000000..2f9b0a64be5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -0,0 +1,399 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Analysis for determining the possible set of values for all locations
+// (instructions and ShapeIndexes) in the HLO module. The analysis is
+// module-scoped, tracking values across computation boundaries.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Abstraction which identifies a specific point in the XLA graph. An
+// HloLocation specifies a ShapeIndex within the output of a specific
+// instruction.
+struct HloLocation {
+  HloInstruction* instruction;
+  ShapeIndex index;
+
+  string ToString() const;
+
+  bool operator==(const HloLocation& other) const {
+    return instruction == other.instruction && index == other.index;
+  }
+  bool operator!=(const HloLocation& other) const { return !(*this == other); }
+};
+
+std::ostream& operator<<(std::ostream& out, const HloLocation& location);
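// [Editor's sketch, not part of this change] Because HloLocation is a plain
// value type with operator==, location lists can be searched with standard
// algorithms. A hypothetical helper, assuming this header's types:
//
//   #include <algorithm>
//   #include <vector>
//
//   bool ContainsLocation(const std::vector<HloLocation>& locations,
//                         const HloLocation& target) {
//     // operator== compares the instruction pointer and the ShapeIndex.
//     return std::find(locations.begin(), locations.end(), target) !=
//            locations.end();
//   }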
+
+// Defines a single use of an HLO value.
+struct HloUse {
+  // Instruction at which the value is used.
+  HloInstruction* instruction;
+
+  // The operand number at which the value appears.
+  int64 operand_number;
+
+  // The shape index within the operand at which the value appears.
+  ShapeIndex operand_index;
+
+  string ToString() const;
+
+  bool operator==(const HloUse& other) const {
+    return instruction == other.instruction &&
+           operand_number == other.operand_number &&
+           operand_index == other.operand_index;
+  }
+
+  bool operator!=(const HloUse& other) const { return !(*this == other); }
+};
+
+std::ostream& operator<<(std::ostream& out, const HloUse& use);
+
+// Class describing a value used by the dataflow analysis. XLA arrays are
+// trivially a single HloValue. Tuples are made up of more than one HloValue:
+// an HloValue for the pointer vector, and an HloValue for each child element.
+//
+// Every HloValue is defined by a particular instruction, and most instructions
+// define only a single HloValue. Instructions which define a single HloValue
+// include array-shaped instructions such as Add, but also tuple-shaped
+// instructions such as Tuple. The Tuple instruction defines a single HloValue
+// which is a vector of pointers to the values containing the Tuple
+// instruction's operands. Though the result of the Tuple instruction includes
+// multiple values, only the top-level HloValue (the vector of pointers) is
+// defined by the Tuple instruction. The values containing the tuple elements
+// are defined by earlier instructions, usually the operands of the Tuple
+// instruction.
+//
+// Instructions which construct both the tuple *and* the tuple elements define
+// more than one HloValue. This includes (at least) tuple-shaped Constant,
+// Parameter, Infeed and While instructions. These tuple-shaped instructions
+// do not assemble a tuple from existing HloValues like the Tuple instruction
+// does, but rather define all the HloValues in the tuple.
+class HloValue {
+ public:
+  using Id = int64;
+
+  // Construct an HloValue defined by 'instruction' at shape index 'index'. If
+  // is_phi is true, then this value is a phi value, for example, at the
+  // parameter of a while body computation. Phi values are only used in the SSA
+  // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
+  HloValue(HloValue::Id id, HloInstruction* instruction,
+           const ShapeIndex& index, bool is_phi = false);
+
+  // Return a unique identifier for this HloValue. This value is used for
+  // stable sorting and iteration.
+  Id id() const { return id_; }
+
+  // Returns whether this value is a phi value.
+  bool is_phi() const { return is_phi_; }
+
+  // Return the location where this value is defined.
+  const HloLocation& DefinitionLocation() const { return locations_[0]; }
+
+  // Return the instruction which defines this HloValue.
+  HloInstruction* instruction() const {
+    return DefinitionLocation().instruction;
+  }
+
+  // Return the shape index at which this HloValue is defined in the output of
+  // instruction().
+  const ShapeIndex& index() const { return DefinitionLocation().index; }
+
+  // Add or remove a location at which the HloValue appears. The definition
+  // location cannot be removed. The uses of the HloValue are updated.
+  void AddLocation(HloInstruction* instruction, const ShapeIndex& index);
+  void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index);
+
+  // Return all locations of the HloValue in the module.
+  const std::vector<HloLocation>& locations() const { return locations_; }
+
+  // Return all uses of the HloValue.
+  const std::vector<HloUse>& uses() const { return uses_; }
+
+  // Set/get whether this HloValue is live out of the module.
+  bool live_out_of_module() const { return live_out_of_module_; }
+
+  bool operator==(const HloValue& other) const;
+  bool operator!=(const HloValue& other) const;
+
+  // Return a single-line string representation of the value.
+  string ToShortString() const;
+
+  string ToString(int indent = 0) const;
+
+ private:
+  // Unique identifier for this HloValue. Used for stable sorting and
+  // iteration.
+  const Id id_;
+
+  // Whether this instruction is a phi value.
+  const bool is_phi_;
+
+  // The set of locations of this HloValue. The first element is always the
+  // location of the definition.
+  std::vector<HloLocation> locations_;
+
+  // The set of uses of this HloValue.
+  std::vector<HloUse> uses_;
+
+  // Whether this value is live out of the HLO module.
+  bool live_out_of_module_ = false;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
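// [Editor's sketch, toy code, not the XLA implementation] The invariant that
// locations_[0] is the definition can be enforced by construction; all names
// here are hypothetical:
//
//   #include <cassert>
//   #include <cstddef>
//   #include <vector>
//
//   struct ToyLocation {
//     int instruction_id;  // stand-in for (instruction, index)
//     bool operator==(const ToyLocation& o) const {
//       return instruction_id == o.instruction_id;
//     }
//   };
//
//   class ToyValue {
//    public:
//     explicit ToyValue(ToyLocation definition) : locations_{definition} {}
//     void AddLocation(ToyLocation loc) { locations_.push_back(loc); }
//     void RemoveLocation(ToyLocation loc) {
//       // The definition (the front element) must never be removed.
//       assert(!(loc == locations_.front()));
//       for (std::size_t i = 1; i < locations_.size(); ++i) {
//         if (locations_[i] == loc) {
//           locations_.erase(locations_.begin() + i);
//           return;
//         }
//       }
//     }
//    private:
//     std::vector<ToyLocation> locations_;  // locations_[0] is the definition
//   };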
+
+// A class representing the possible set of HloValues at a particular point
+// (shape index in the output of an instruction) in the XLA graph. This set
+// contains the set of reaching HloValue definitions. For a simple array-shaped
+// instruction like Add, the HloValueSet of the top-level of the instruction's
+// output trivially contains only the HloValue defined by the instruction. For
+// instructions which have non-trivial dataflow such as Tuple or Select, the
+// HloValueSets of the instruction's output contain one or more HloValues
+// defined by the instruction's operands or defined further up in the XLA
+// graph.
+class HloValueSet {
+ public:
+  HloValueSet() = default;
+
+  explicit HloValueSet(tensorflow::gtl::ArraySlice<HloValue::Id> value_ids)
+      : value_ids_(value_ids.begin(), value_ids.end()) {
+    SortAndUniquifyValues();
+  }
+
+  // Return the union of the given HloValueSets.
+  static HloValueSet Union(
+      tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
+
+  // Return the vector of the IDs of all HloValues in the set. Values in the
+  // vector are unique and sorted.
+  const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
+
+  // Return the unique HLO value in the set. CHECKs if the set does not contain
+  // exactly one value.
+  HloValue::Id GetUniqueValueId() const {
+    CHECK_EQ(value_ids().size(), 1);
+    return value_ids()[0];
+  }
+
+  bool operator==(const HloValueSet& other) const {
+    return value_ids() == other.value_ids();
+  }
+  bool operator!=(const HloValueSet& other) const { return !(*this == other); }
+
+  string ToString() const;
+
+ private:
+  // Sorts value_ids_ and removes duplicates. This should be called after
+  // adding any elements to value_ids_.
+  void SortAndUniquifyValues();
+
+  // HloValues sorted by HloValue::Id.
+  std::vector<HloValue::Id> value_ids_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
+
+// A class collecting the HloValues which might be contained in the output of
+// an HLO instruction. For array-shaped instructions, an InstructionValueSet
+// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
+// hold multiple HloValueSets.
+class InstructionValueSet : public ShapeTree<HloValueSet> {
+ public:
+  InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}
+
+  // Return the union of the given InstructionValueSets.
+  static InstructionValueSet Union(
+      tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+
+  string ToString() const;
+};
+
+std::ostream& operator<<(std::ostream& out,
+                         const InstructionValueSet& instruction_value_set);
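// [Editor's sketch, toy code, not the XLA implementation] With the sorted,
// duplicate-free vector representation above, set union reduces to
// std::set_union:
//
//   #include <algorithm>
//   #include <cstdint>
//   #include <iterator>
//   #include <vector>
//
//   using ToyId = int64_t;  // stand-in for HloValue::Id
//
//   std::vector<ToyId> UnionOfSortedIds(const std::vector<ToyId>& a,
//                                       const std::vector<ToyId>& b) {
//     std::vector<ToyId> result;
//     // Inputs are sorted and unique, so the output is as well.
//     std::set_union(a.begin(), a.end(), b.begin(), b.end(),
//                    std::back_inserter(result));
//     return result;
//   }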
+
+// Analysis which identifies all HLO values and their uses in an HLO module.
+class HloDataflowAnalysis {
+ public:
+  // Run dataflow analysis on the given module. Parameters:
+  //
+  //   ssa_form : If true then new values are defined at the merge points of
+  //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
+  //     values". The merge is formed by the init value and loop backedge. The
+  //     SSA form is minimal in that a new phi value is defined only if the
+  //     merge point is reachable by multiple different values. The SSA form is
+  //     also in loop-closed form in that no values defined inside of a loop
+  //     (while body) are used outside of the loop.
+  //
+  //     If ssa_form is false, then merge points do not define new
+  //     values. Rather, the HloValueSet for the merge point contains the union
+  //     of the merged HloValues.
+  //
+  //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
+  //     a new HLO value in the analysis. If false then Bitcast forwards the
+  //     value of its operand.
+  static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
+      HloModule* module, bool ssa_form = false,
+      bool bitcast_defines_value = false);
+
+  // Returns true if 'instruction' defines an HLO value at the given shape
+  // index of its output.
+  bool ValueIsDefinedAt(const HloInstruction* instruction,
+                        const ShapeIndex& index = {}) const;
+
+  // Return the HloValue defined by 'instruction' at the given shape index of
+  // its output.
+  //
+  // Precondition: ValueIsDefinedAt is true for this instruction and index.
+  const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
+                                    const ShapeIndex& index = {}) const;
+  HloValue& GetValueDefinedAt(const HloInstruction* instruction,
+                              const ShapeIndex& index = {});
+
+  // Return the InstructionValueSet for the given instruction.
+  const InstructionValueSet& GetInstructionValueSet(
+      const HloInstruction* instruction) const;
+  InstructionValueSet& GetInstructionValueSet(
+      const HloInstruction* instruction);
+
+  // Return the HloValueSet for the given instruction at the given index.
+  const HloValueSet& GetValueSet(const HloInstruction* instruction,
+                                 const ShapeIndex& index = {}) const;
+  HloValueSet& GetValueSet(const HloInstruction* instruction,
+                           const ShapeIndex& index = {});
+
+  // Return the unique value in the HloValueSet at the given instruction and
+  // shape index. CHECKs if the value set does not contain exactly one value.
+  const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
+                                   const ShapeIndex& index = {}) const {
+    return GetValue(GetValueSet(instruction, index).GetUniqueValueId());
+  }
+  HloValue& GetUniqueValueAt(const HloInstruction* instruction,
+                             const ShapeIndex& index = {}) {
+    return GetValue(GetValueSet(instruction, index).GetUniqueValueId());
+  }
+
+  // Return the HloValue with the given Id.
+  const HloValue& GetValue(HloValue::Id value_id) const;
+  HloValue& GetValue(HloValue::Id value_id);
+
+  // Return the total number of HloValues.
+  int64 value_count() const { return values_.size(); }
+
+  // Return a vector of all HloValues stably sorted by HloValue::Id. This
+  // vector is lazily computed. Mutating operations on HloDataflowAnalysis may
+  // invalidate the underlying vector, requiring recomputation.
+  const std::vector<const HloValue*>& values() const;
+
+  string ToString() const;
+
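// [Editor's sketch, hypothetical usage, not from this change] A pass might
// consume the public API above like this, assuming TF_ASSIGN_OR_RETURN from
// tensorflow/compiler/xla/status_macros.h and a valid HloModule* 'module';
// the helper name is made up for illustration:
//
//   Status LogLiveOutValues(HloModule* module) {
//     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
//                         HloDataflowAnalysis::Run(module, /*ssa_form=*/true));
//     for (const HloValue* value : dataflow->values()) {
//       if (value->live_out_of_module()) {
//         VLOG(1) << "Live out: " << value->ToShortString();
//       }
//     }
//     return Status::OK();
//   }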
+ protected:
+  HloDataflowAnalysis(HloModule* module, bool ssa_form,
+                      bool bitcast_defines_value = false);
+
+  // Creates a new HloValue defined at the given instruction and shape index,
+  // and returns its ID.
+  HloValue::Id NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
+                           bool is_phi = false);
+
+  // Delete the HloValue with the given ID.
+  void DeleteHloValue(HloValue::Id value_id);
+
+  // Constructs and initializes the InstructionValueSets of all instructions to
+  // contain exactly the HloValues defined by each instruction. These values
+  // can then be propagated throughout the HLO graph by calling
+  // UpdateInstructionsAndPropagate.
+  Status InitializeInstructionValueSets();
+
+  // Updates the value set of the given instruction based on the values flowing
+  // into the instruction (operands and cross-computation dataflow).
+  void UpdateInstructionValueSet(HloInstruction* instruction);
+
+  // Recomputes and returns the value set for the given instruction.
+  InstructionValueSet RecomputeBitcastValueSet(HloInstruction* bitcast);
+  InstructionValueSet RecomputeCopyValueSet(HloInstruction* copy);
+  InstructionValueSet RecomputeGetTupleElementValueSet(HloInstruction* gte);
+  InstructionValueSet RecomputeParameterValueSet(HloInstruction* parameter);
+  InstructionValueSet RecomputeSelectValueSet(HloInstruction* select);
+  InstructionValueSet RecomputeTupleValueSet(HloInstruction* tuple);
+  InstructionValueSet RecomputeWhileValueSet(HloInstruction* xla_while);
+
+  // Update the value sets of the given instructions and propagate the
+  // changes to fixed point.
+  void UpdateInstructionsAndPropagate(
+      tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
+
+  // Return the result of the SSA Phi function applied to the given inputs at
+  // the given instruction. If skip_top_level is true, then the top level of
+  // the value set of 'instruction' is not modified.
+  InstructionValueSet Phi(
+      HloInstruction* instruction,
+      tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
+      bool skip_top_level = false);
+
+  // Updates the locations of the HloValues in the output of the given
+  // instruction. This should be called after the instruction value set of
+  // 'instruction' has been changed. 'prev_value_set' must point to the
+  // previous state of the value set prior to the change. 'prev_value_set' may
+  // be null if this is the first time locations are being computed. The
+  // previous state is necessary to efficiently remove locations which have
+  // been eliminated due to changes in the instructions' InstructionValueSets.
+  void UpdateLocationsOfValuesAt(
+      HloInstruction* instruction, const InstructionValueSet& new_value_set,
+      const InstructionValueSet* prev_value_set = nullptr);
+
+  HloModule* const module_;
+  const bool ssa_form_;
+  const bool bitcast_defines_value_;
+
+  std::unique_ptr<CallGraph> call_graph_;
+
+  // The map of all HloValues in the module.
+  std::unordered_map<HloValue::Id, HloValue> values_;
+
+  // A map from instruction to InstructionValueSet.
+  std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
+
+  // A lazily constructed vector containing all HloValues sorted by
+  // HloValue::Id.
+  mutable std::vector<const HloValue*> values_vector_;
+
+  // The Id to use for the next HloValue.
+  HloValue::Id next_value_id_ = 0;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
new file mode 100644
index 00000000000..21344af5f22
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -0,0 +1,1134 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ::testing::UnorderedElementsAre;
+
+// Test is parameterized on a bool which indicates whether the dataflow
+// analysis is performed with SSA form.
+class HloDataflowAnalysisTest : public HloTestBase,
+                                public ::testing::WithParamInterface<bool> {
+ protected:
+  HloDataflowAnalysisTest() : module_(TestName()) {}
+
+  // Run dataflow analysis on the member module. For convenience returns a
+  // reference to the generated analysis stored in analysis_.
+  const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
+                                         bool bitcast_defines_value = false) {
+    analysis_ =
+        HloDataflowAnalysis::Run(&module_, ssa_form, bitcast_defines_value)
+            .ConsumeValueOrDie();
+    return *analysis_;
+  }
+
+  // Return a vector of the HloValues at the given program location.
+  std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
+                                    const ShapeIndex& index = {}) {
+    CHECK(analysis_ != nullptr);
+    std::vector<HloValue> values;
+    for (HloValue::Id value_id :
+         analysis_->GetValueSet(instruction, index).value_ids()) {
+      values.push_back(analysis_->GetValue(value_id));
+    }
+    return values;
+  }
+
+  HloModule module_;
+  std::unique_ptr<HloDataflowAnalysis> analysis_;
+
+  const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
+  // Test the dataflow for a simple binary operation (Add).
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, constant1, constant2));
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  // Each instruction should define a single value.
+  EXPECT_EQ(analysis.values().size(), 3);
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
+
+  // Verify the locations of the values. These locations are all trivial
+  // because there are no instructions which forward values.
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(), + UnorderedElementsAre(HloLocation{constant1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(), + UnorderedElementsAre(HloLocation{constant2, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(), + UnorderedElementsAre(HloLocation{add, {}})); + + // Verify the uses of the values. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 1, {}})); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).uses().empty()); + + // Verify liveout values from the module. + EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { + // Verify the dataflow through a Tuple and GetTupleElement instructions. + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + // The two params, tuple, and add should each define one value. + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param0)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(param1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + + // Verify the locations of the values. + EXPECT_THAT( + analysis.GetValueDefinedAt(param0).locations(), + UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, + HloLocation{gte0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(param1).locations(), + UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}}, + HloLocation{gte1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(), + UnorderedElementsAre(HloLocation{tuple, {}})); + + // Verify uses. Of interest is that a GetTupleElement instruction is only a + // use of the top-level value in the tuple operand. 
+ EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(), + UnorderedElementsAre(HloUse{tuple, 0, {}}, HloUse{add, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(), + UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{add, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}})); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, NestedTuple) { + // Verify the dataflow through a nested tuple of the following form for two + // constants %constant1 and %constant2: + // + // %nested_tuple = {{%constant1, %constant2}, + // {%constant1, %constant2}, + // %constant1} + // + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto nested_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({tuple, tuple, constant1})); + auto gte_tuple = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1)); + auto gte_out = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + // Verify locations and uses. + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).locations(), + UnorderedElementsAre( + HloLocation{constant1, {}}, HloLocation{tuple, {0}}, + HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}}, + HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}}, + HloLocation{gte_out, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre( + HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}}, + HloUse{nested_tuple, 1, {0}}, HloUse{nested_tuple, 2, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{nested_tuple, 0, {1}}, + HloUse{nested_tuple, 1, {1}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{nested_tuple, 0, {}}, + HloUse{nested_tuple, 1, {}}, + HloUse{gte_out, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte_tuple, 0, {}})); + + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); + EXPECT_FALSE( + analysis.GetValueDefinedAt(tuple, /*index=*/{}).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}) + .live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, SingleCall) { + // Test a single call of a subcomputation. The subcomputation adds its two + // array-shaped parameters. 
+ auto subbuilder = HloComputation::Builder("Subcomputation"); + auto subparam0 = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto subparam1 = subbuilder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); + HloComputation* called_computation = + module_.AddEmbeddedComputation(subbuilder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto call = builder.AddInstruction(HloInstruction::CreateCall( + scalar_shape_, {constant1, constant2}, called_computation)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 3); + + // The parameters of the subcomputation and the call instruction itself should + // not define values. Their values flow from elsewhere. + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(call)); + + EXPECT_EQ(analysis.GetUniqueValueAt(subparam0), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(subparam1), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); + + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call, 1, {}})); + + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { + // Test a subcomputation which is called twice with identical values. 
+ auto subbuilder = HloComputation::Builder("Subcomputation"); + auto subparam0 = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto subparam1 = subbuilder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); + HloComputation* called_computation = + module_.AddEmbeddedComputation(subbuilder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto call1 = builder.AddInstruction(HloInstruction::CreateCall( + scalar_shape_, {constant1, constant2}, called_computation)); + auto call2 = builder.AddInstruction(HloInstruction::CreateCall( + scalar_shape_, {constant1, constant2}, called_computation)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, call1, call2)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + // Definitions should be identical to the single callsite case. + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(call1)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(call2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call1, 0, {}}, + HloUse{call2, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call1, 1, {}}, + HloUse{call2, 1, {}})); + // The Add from the subcomputation is used as both operands of the Subtract. + EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), + UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); + + EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { + // Test a subcomputation which is called twice with different argument values. 
+  auto subbuilder = HloComputation::Builder("Subcomputation");
+  auto subparam0 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto subparam1 = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
+  HloComputation* called_computation =
+      module_.AddEmbeddedComputation(subbuilder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
+      scalar_shape_, {constant1, constant2}, called_computation));
+  auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
+      scalar_shape_, {call1, constant2}, called_computation));
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
+
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
+
+  EXPECT_THAT(HloValuesAt(subparam0),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+                                   analysis.GetValueDefinedAt(add)));
+  EXPECT_THAT(HloValuesAt(subparam1),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(constant2)));
+
+  EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
+}
+
+TEST_P(HloDataflowAnalysisTest, NestedCalls) {
+  // Test a module with nested computations. HLO is:
+  //
+  // F32[] inner_computation(F32[] %param0, F32[] %param1):
+  //   %add = Add(%param0, %param1)
+  //
+  // F32[] outer_computation(F32[] %param0, F32[] %param1):
+  //   ;; Note that parameters are interchanged in the call.
+  //   %nested_call = Call(inner_computation, {%param1, %param0})
+  //
+  // F32[] entry:
+  //   %constant1 = Constant(1.0)
+  //   %constant2 = Constant(2.0)
+  //   %call = Call(outer_computation, {%constant1, %constant2})
+  //
+  auto inner_builder = HloComputation::Builder("InnerComputation");
+  auto inner_param0 = inner_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto inner_param1 = inner_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
+  HloComputation* inner_computation =
+      module_.AddEmbeddedComputation(inner_builder.Build());
+
+  auto outer_builder = HloComputation::Builder("OuterComputation");
+  auto outer_param0 = outer_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto outer_param1 = outer_builder.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
+  // Swizzle parameters.
+ auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( + scalar_shape_, {outer_param1, outer_param0}, inner_computation)); + HloComputation* outer_computation = + module_.AddEmbeddedComputation(outer_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto call = builder.AddInstruction(HloInstruction::CreateCall( + scalar_shape_, {constant1, constant2}, outer_computation)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + // Only three values should be defined. Most instructions just pass through + // their operand values. + EXPECT_EQ(analysis.values().size(), 3); + + // Verify that the uses of the constants are properly swizzled by parameter + // permutation in nested_call. + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, + HloUse{add, 1, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, + HloUse{add, 0, {}})); + + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); +} + +TEST_P(HloDataflowAnalysisTest, SingleWhile) { + // Test a simple single while instruction. The while body includes a + // pass-through value. HLO: + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple = Tuple(%constant1, %constant2) + // return While(%tuple, body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". 
+ auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + if (ssa_form) { + // Element 0 of the tuple passed through the body so no phi value is + // defined. + EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); + + // Element 1 of the tuple should be a phi value. + EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); + EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi()); + EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); + + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{tuple, 0, {}}, + HloUse{xla_while, 0, {0}}, + HloUse{body_tuple, 0, {}})); + + // Constant1 passes through the body and out of the module. + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) + .live_out_of_module()); + } else { + // While instruction and subcomputation parameters should not define values + // in non-ssa form. + EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); + + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + } +} + +TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { + // Test sequential while instructions. The while body includes a + // pass-through value. 
HLO:
+  //
+  // body((F32[], F32[]) %tuple_param):
+  //   %add = Add(%tuple_param{0}, %tuple_param{1})
+  //   return Tuple(%tuple_param{0}, %add)
+  //
+  // condition((F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %constant1 = Constant(1.0)
+  //   %constant2 = Constant(2.0)
+  //   %tuple = Tuple(%constant1, %constant2)
+  //   %while0 = While(%tuple, body, condition)
+  //   %while1 = While(%while0, body, condition)
+  //   return While(%while1, body, condition)
+  //
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  // Element 0 passes transparently through the body.
+  auto body_builder = HloComputation::Builder("body");
+  auto body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto body_element_0 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+  auto body_element_1 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+  auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
+  body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_element_0, add}));
+  HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+
+  auto cond_builder = HloComputation::Builder("condition");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_.AddEmbeddedComputation(cond_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({constant1, constant2}));
+  auto xla_while0 = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
+  auto xla_while1 = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
+  auto xla_while2 = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  // Element 0 is passed through all the while instructions and out of the
+  // module.
+  EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
+            analysis.GetValueDefinedAt(constant1));
+  EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
+            analysis.GetValueDefinedAt(constant1));
+  EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
+            analysis.GetValueDefinedAt(constant1));
+  EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+}
+
+TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
+  // Test nested while instructions. The inner body passes through element 0
+  // of its parameter, and the outer body passes through element 1.
HLO: + // + // inner_body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // outer_body((F32[], F32[]) %tuple_param): + // %negate = Negate(%tuple_param{0}) + // %tuple = Tuple(%negate, %tuple_param{1}) + // return While(%tuple, inner_body, condition) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple = Tuple(%constant1, %constant2) + // return While(%tuple, outer_body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + // Element 0 passes transparently through the body. + auto inner_builder = HloComputation::Builder("inner_body"); + auto inner_param = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto inner_element_0 = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0)); + auto inner_element_1 = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1)); + auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1)); + inner_builder.AddInstruction( + HloInstruction::CreateTuple({inner_element_0, add})); + HloComputation* inner_body = + module_.AddEmbeddedComputation(inner_builder.Build()); + + // Element 1 passes transparently through the body. 
+  auto outer_builder = HloComputation::Builder("outer_body");
+  auto outer_param = outer_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto outer_element_0 = outer_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
+  auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
+      scalar_shape_, HloOpcode::kNegate, outer_element_0));
+  auto outer_element_1 = outer_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
+  auto outer_tuple = outer_builder.AddInstruction(
+      HloInstruction::CreateTuple({negate, outer_element_1}));
+  auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
+      tuple_shape, condition, inner_body, outer_tuple));
+  HloComputation* outer_body =
+      module_.AddEmbeddedComputation(outer_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({constant1, constant2}));
+  auto entry_while = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
+  if (ssa_form) {
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
+    EXPECT_TRUE(
+        analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
+
+    // Element 0 of the nested while is %negate.
+    EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
+    EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
+    // Element 1 is a phi value (join of %add and %constant2).
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
+    EXPECT_TRUE(
+        analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
+
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
+    EXPECT_TRUE(
+        analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());
+
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
+    EXPECT_TRUE(
+        analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
+  } else {
+    EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+                                     analysis.GetValueDefinedAt(constant2)));
+
+    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
+    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+                                     analysis.GetValueDefinedAt(constant2)));
+
+    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
+                                     analysis.GetValueDefinedAt(constant1)));
+    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+                                     analysis.GetValueDefinedAt(constant2)));
+  }
+}
+
+TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
+  // Test a while instruction with a body which permutes its tuple parameter
+  // elements.
HLO: + // + // body((F32[], F32[]) %tuple_param): + // return Tuple(%tuple_param{1}, %tuple_param{0}) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %constant1 = Constant(1.0) + // %constant2 = Constant(2.0) + // %tuple = Tuple(%constant1, %constant2) + // return While(%tuple, body, condition) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_.AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + if (ssa_form) { + // Element 0 and 1 in the while should both be phi values. + EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); + EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi()); + EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi()); + EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); + EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi()); + EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); + EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); + + EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{}) + .live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}) + .live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) + .live_out_of_module()); + } else { + // Elements 0 and 1 have both constants as reaching definitions. 
+    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+                                     analysis.GetValueDefinedAt(constant2)));
+    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
+                UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+                                     analysis.GetValueDefinedAt(constant2)));
+    EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+    EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
+  }
+}
+
+TEST_P(HloDataflowAnalysisTest, ArraySelect) {
+  // Test a kSelect of an array value.
+  auto builder = HloComputation::Builder(TestName());
+  auto pred = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+      scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2));
+
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
+  EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+  EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
+  EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module());
+}
+
+TEST_P(HloDataflowAnalysisTest, TupleSelect) {
+  // Test a kSelect of a tuple value. Non-top-level elements flow through the
+  // instruction.
+  auto builder = HloComputation::Builder(TestName());
+  auto pred = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto constant3 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+  auto constant4 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
+  auto tuple1 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
+  auto tuple2 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
+  auto tuple3 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
+  auto tuple4 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
+  const Shape tuple_shape = tuple1->shape();
+  auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
+      tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
+  auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
+      tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
+  auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
+      tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
+  auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
+      tuple_shape, HloOpcode::kSelect, pred, select12, select34));
+
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  // The top-level value is always defined by a kSelect.
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(select11)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(select12)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(select34)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0})); + + EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1))); + EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant3), + analysis.GetValueDefinedAt(constant4))); + EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2), + analysis.GetValueDefinedAt(constant3), + analysis.GetValueDefinedAt(constant4))); + + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}}, + HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}}, + HloUse{select1234, 1, {0}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}}, + HloUse{select1234, 1, {0}})); +} + +TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { + // Test kSelect of a nested tuple. + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + auto constant4 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + auto constant5 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); + auto inner_tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant2, constant3})); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, inner_tuple1})); + auto inner_tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant5, constant3})); + auto tuple2 = builder.AddInstruction( + HloInstruction::CreateTuple({constant4, inner_tuple2})); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(select)); + + EXPECT_THAT(HloValuesAt(select, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant4))); + EXPECT_THAT(HloValuesAt(select, /*index=*/{1}), + UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1), + analysis.GetValueDefinedAt(inner_tuple2))); + EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant2), + analysis.GetValueDefinedAt(constant5))); + 
+  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
+}
+
+TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
+  // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
+  //
+  // body((F32[], F32[]) %tuple_param):
+  //   %add = Add(%tuple_param{0}, %tuple_param{1})
+  //   return Tuple(%tuple_param{0}, %add)
+  //
+  // condition((F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %constant1 = Constant(1.0)
+  //   %constant2 = Constant(2.0)
+  //   %constant3 = Constant(3.0)
+  //   %tuple1 = Tuple(%constant1)
+  //   %tuple2 = Tuple(%constant2)
+  //   %select = Select(%pred, %tuple1, %tuple2)
+  //   %gte = GetTupleElement(%select, 0)
+  //   %tuple = Tuple(%gte, %constant3)
+  //   return While(%tuple, body, condition)
+  //
+  auto builder = HloComputation::Builder(TestName());
+
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  // Element 0 passes transparently through the body.
+  auto body_builder = HloComputation::Builder("body");
+  auto body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto body_element_0 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+  auto body_element_1 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+  auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
+  body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_element_0, add}));
+  HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+
+  auto cond_builder = HloComputation::Builder("condition");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_.AddEmbeddedComputation(cond_builder.Build());
+
+  auto pred = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+  auto constant3 = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+  auto tuple1 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
+  auto tuple2 =
+      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
+  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+      tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+  auto gte = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
+  auto xla_while = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple));
+
+  module_.AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  if (ssa_form) {
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
+    EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
+    EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
+    EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
+
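+    // In SSA form the values reaching the kWhile from the init operand and
+    // from the loop backedge are merged into phi values, analogous to phi
+    // nodes in SSA-based compiler IRs.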
EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0})); + + EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) + .live_out_of_module()); + } else { + EXPECT_THAT(HloValuesAt(gte), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(constant3))); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant3).live_out_of_module()); + } +} + +TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { + // Test the bitcast_defines_value flag to the dataflow analysis. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kBitcast, constant)); + + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + { + const HloDataflowAnalysis& analysis = + RunAnalysis(ssa_form, /*bitcast_defines_value=*/true); + + EXPECT_EQ(analysis.values().size(), 2); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(bitcast)); + EXPECT_FALSE(analysis.GetValueDefinedAt(constant).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(bitcast).live_out_of_module()); + } + { + const HloDataflowAnalysis& analysis = + RunAnalysis(ssa_form, /*bitcast_defines_value=*/false); + EXPECT_EQ(analysis.values().size(), 1); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(bitcast)); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module()); + } +} + +TEST_P(HloDataflowAnalysisTest, TupleCopy) { + // Test that a tuple-shaped copy only copies (defines) the top-level value. 
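+  // The leaf values (param0 and param1) flow through the copy unchanged, so
+  // the analysis should find exactly four values in total: two parameters,
+  // the tuple, and the copy.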
+ auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); + module_.AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param0)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(param1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1})); + + EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param0))); + EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param1))); + EXPECT_TRUE( + analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); +} + +INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, + HloDataflowAnalysisTest, + ::testing::Values(false, true)); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fdfbbf8baf6..3755b9e4c00 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -52,7 +52,7 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto& instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction.get()) == 0 && - HloComputation::IsRemovable(instruction->opcode())) { + computation->IsRemovable(instruction.get())) { dead_roots.push_back(instruction.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index dcd9e00c56c..10cd7ca7c09 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +51,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -80,7 +81,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -93,5 +94,69 @@ TEST_F(HloDceTest, DeadParameters) { EXPECT_EQ(0, dead_param1->user_count()); } +TEST_F(HloDceTest, ControlDependencies) { + // Verify that instructions with control dependencies are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + + // Create two dead instructions: a negate and an add. + auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( + constant1->shape(), HloOpcode::kNegate, constant1)); + auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + // Create the same two instructions again, but these will have a control + // dependency added. + auto dead_negate_with_control_dep = + builder.AddInstruction(HloInstruction::CreateUnary( + constant1->shape(), HloOpcode::kNegate, constant1)); + auto dead_add_with_control_dep = + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + // Create a root so the previously added instruction is dead. + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Add a control dependency between two instructions. + TF_ASSERT_OK(dead_negate_with_control_dep->AddControlDependencyTo( + dead_add_with_control_dep)); + + // Returns whether the given instruction exists in the test computation. 
+ auto has_instruction = [computation](const HloInstruction* instruction) { + for (auto& inst : computation->instructions()) { + if (inst.get() == instruction) { + return true; + } + } + return false; + }; + + EXPECT_EQ(7, computation->instruction_count()); + EXPECT_TRUE(has_instruction(dead_negate)); + EXPECT_TRUE(has_instruction(dead_add)); + EXPECT_TRUE(has_instruction(dead_negate_with_control_dep)); + EXPECT_TRUE(has_instruction(dead_add_with_control_dep)); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(5, computation->instruction_count()); + EXPECT_FALSE(has_instruction(dead_negate)); + EXPECT_FALSE(has_instruction(dead_add)); + EXPECT_TRUE(has_instruction(dead_negate_with_control_dep)); + EXPECT_TRUE(has_instruction(dead_add_with_control_dep)); +} + } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc new file mode 100644 index 00000000000..3e7f5b1f3d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -0,0 +1,791 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+namespace {
+
+template <typename OperandT>
+StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
+                                           const Literal& lhs_literal,
+                                           const Literal& rhs_literal) {
+  std::function<bool(OperandT, OperandT)> compare_op;
+  switch (opcode) {
+    case HloOpcode::kEq:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el == rhs_el;
+      };
+      break;
+    case HloOpcode::kNe:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el != rhs_el;
+      };
+      break;
+    case HloOpcode::kGe:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el >= rhs_el;
+      };
+      break;
+    case HloOpcode::kGt:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el > rhs_el;
+      };
+      break;
+    case HloOpcode::kLe:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el <= rhs_el;
+      };
+      break;
+    case HloOpcode::kLt:
+      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
+        return lhs_el < rhs_el;
+      };
+      break;
+    default:
+      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
+                 << HloOpcodeString(opcode);
+  }
+
+  auto result = LiteralUtil::CreateFromShape(shape);
+  TF_RETURN_IF_ERROR(LiteralUtil::Populate<bool>(
+      result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+        return compare_op(LiteralUtil::Get<OperandT>(lhs_literal, multi_index),
+                          LiteralUtil::Get<OperandT>(rhs_literal, multi_index));
+      }));
+
+  return std::move(result);
+}
+
+template <typename ReturnT, typename NativeT>
+StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+    HloInstruction* instruction,
+    const std::function<ReturnT(NativeT)>& unary_op,
+    const Literal& operand_literal) {
+  const auto shape = instruction->shape();
+  const auto* operand = instruction->operand(0);
+
+  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+  // removed.
+  if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
+    return Unimplemented(
+        "Implicit broadcasting is currently unsupported in HLO evaluator; "
+        "shape mismatch: %s vs %s",
+        ShapeUtil::HumanString(shape).c_str(),
+        ShapeUtil::HumanString(operand->shape()).c_str());
+  }
+
+  auto result = LiteralUtil::CreateFromShape(shape);
+
+  TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+      result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+        return unary_op(
+            LiteralUtil::Get<NativeT>(operand_literal, multi_index));
+      }));
+  return std::move(result);
+}
+
+}  // namespace
+
+template <typename ReturnT>
+class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
+ public:
+  explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}
+
+  Status DefaultAction(HloInstruction* hlo_instruction) override {
+    return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
+                         HloOpcodeString(hlo_instruction->opcode()).c_str());
+  };
+
+  // TODO(b/35950897): many of the stl functions used in the handlers are not
+  // overloaded for every XLA primitive type.
+
+  template <typename NativeT,
+            typename std::enable_if<
+                std::is_unsigned<NativeT>::value>::type* = nullptr>
+  Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
+                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
+                          return elem_operand;
+                        }));
+    return Status::OK();
+  }
+
+  template <
+      typename NativeT,
+      typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
+  Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
+                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
+                          return std::abs(elem_operand);
+                        }));
+    return Status::OK();
+  }
+
+  Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override {
+    return HandleAbs<ReturnT>(abs, operand);
+  };
+
+  Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
+                        ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) {
+                          return std::ceil(elem_operand);
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy],
+                        ElementWiseUnaryOp(copy, [](ReturnT elem_operand) {
+                          return elem_operand;
+                        }));
+    return Status::OK();
+  };
+
+  template <PrimitiveType src_type, PrimitiveType dest_type>
+  std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
+    DCHECK_EQ(src_type, src_literal.shape().element_type());
+    return LiteralUtil::Convert<
+        typename primitive_util::PrimitiveTypeToNative<src_type>::type,
+        typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(
+        src_literal);
+  }
+
+  Status HandleConvert(HloInstruction* convert,
+                       HloInstruction* operand) override {
+    auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+
+    switch (operand->shape().element_type()) {
+#define CONVERT_IF_TYPES_MATCH(src_type)                                \
+  case (src_type):                                                      \
+    parent_->evaluated_[convert] = LiteralUtil::Convert<                \
+        typename primitive_util::PrimitiveTypeToNative<src_type>::type, \
+        ReturnT>(operand_literal);                                      \
+    break;
+      CONVERT_IF_TYPES_MATCH(PRED)
+      CONVERT_IF_TYPES_MATCH(S8)
+      CONVERT_IF_TYPES_MATCH(S32)
+      CONVERT_IF_TYPES_MATCH(S64)
+      CONVERT_IF_TYPES_MATCH(U8)
+      CONVERT_IF_TYPES_MATCH(U32)
+      CONVERT_IF_TYPES_MATCH(U64)
+      CONVERT_IF_TYPES_MATCH(F32)
+      CONVERT_IF_TYPES_MATCH(F64)
+#undef CONVERT_IF_TYPES_MATCH
+      // Other types are not yet supported.
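+      // (Element types without a case above, e.g. F16, fall through to the
+      // LOG(FATAL) in the default case below.)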
+      default:
+        LOG(FATAL) << "unimplemented operand type for HandleConvert: "
+                   << PrimitiveType_Name(operand->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
+  Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
+                        ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
+                          return std::exp(elem_operand);
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor],
+                        ElementWiseUnaryOp(floor, [](ReturnT elem_operand) {
+                          return std::floor(elem_operand);
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
+                        ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
+                          return std::log(elem_operand);
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleLogicalNot(HloInstruction* logical_not,
+                          HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[logical_not],
+        ElementWiseUnaryOp(logical_not,
+                           [](ReturnT elem_operand) { return !elem_operand; }));
+    return Status::OK();
+  };
+
+  Status HandleNegate(HloInstruction* negate,
+                      HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate],
+                        ElementWiseUnaryOp(negate, [](ReturnT elem_operand) {
+                          return -elem_operand;
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
+                        ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
+                          return (ReturnT(0) < elem_operand) -
+                                 (elem_operand < ReturnT(0));
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
+                        ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) {
+                          return std::tanh(elem_operand);
+                        }));
+    return Status::OK();
+  };
+
+  Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
+                        HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[multiply],
+        ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) {
+          return lhs_elem * rhs_elem;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
+                        HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[subtract],
+        ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) {
+          return lhs_elem - rhs_elem;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
+                   HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[add],
+        ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) {
+          return lhs_elem + rhs_elem;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
+                      HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[divide],
+        ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) {
+          return lhs_elem / rhs_elem;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
+                       HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[maximum],
+        ElementWiseBinaryOp(maximum, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return std::fmax(lhs_el, rhs_el);
+        }));
+    return Status::OK();
+  };
+
+  Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
+                       HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[minimum],
+        ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return std::fmin(lhs_el, rhs_el);
+        }));
+    return Status::OK();
+  };
+
+  Status HandlePower(HloInstruction* power, HloInstruction* lhs,
+                     HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[power],
+        ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return std::pow(lhs_el, rhs_el);
+        }));
+    return Status::OK();
+  };
+
+  Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
+                         HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[remainder],
+        ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return std::fmod(lhs_el, rhs_el);
+        }));
+    return Status::OK();
+  };
+
+  Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs,
+                          HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[logical_and],
+        ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return lhs_el && rhs_el;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs,
+                         HloInstruction* rhs) override {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[logical_or],
+        ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) {
+          return lhs_el || rhs_el;
+        }));
+    return Status::OK();
+  };
+
+  Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
+                     HloInstruction* arg, HloInstruction* max) override {
+    std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
+        [](ReturnT low, ReturnT high, ReturnT value) {
+          return std::fmax(low, std::fmin(value, high));
+        };
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
+                        ElementWiseTernaryOp(clamp, std::move(clamp_op)));
+    return Status::OK();
+  };
+
+  Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+                      HloInstruction* on_true,
+                      HloInstruction* on_false) override {
+    CHECK(!ShapeUtil::IsTuple(select->shape()));
+    std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
+        [](bool pred, ReturnT on_true, ReturnT on_false) {
+          if (pred) {
+            return on_true;
+          }
+          return on_false;
+        };
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
+                        ElementWiseTernaryOp(select, std::move(select_op)));
+    return Status::OK();
+  };
+
+  Status Preprocess(HloInstruction* hlo) override {
+    VLOG(2) << hlo->ToString();
+    return Status::OK();
+  };
+
+ private:
+  StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+      HloInstruction* instruction,
+      const std::function<ReturnT(ReturnT)>& unary_op) {
+    const Literal& operand_literal =
+        parent_->GetEvaluatedLiteralFor(instruction->operand(0));
+    return ElementWiseUnaryOpImpl<ReturnT, ReturnT>(instruction, unary_op,
+                                                    operand_literal);
+  }
+
+  StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+      HloInstruction* instruction,
+      const std::function<ReturnT(ReturnT, ReturnT)>& binary_op) {
+    const auto shape = instruction->shape();
+    const auto* lhs = instruction->operand(0);
+    const auto* rhs = instruction->operand(1);
+
+    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+    // removed.
+    if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
+          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
+      return Unimplemented(
+          "Implicit broadcasting is currently unsupported in HLO evaluator; "
+          "shape mismatch: %s vs %s vs %s",
+          ShapeUtil::HumanString(shape).c_str(),
+          ShapeUtil::HumanString(lhs->shape()).c_str(),
+          ShapeUtil::HumanString(rhs->shape()).c_str());
+    }
+
+    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+    auto result = LiteralUtil::CreateFromShape(shape);
+
+    TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+        result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          return binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
+                           LiteralUtil::Get<ReturnT>(rhs_literal, multi_index));
+        }));
+    return std::move(result);
+  }
+
+  template <typename LhsType, typename RhsType, typename EhsType>
+  StatusOr<std::unique_ptr<Literal>> ElementWiseTernaryOp(
+      HloInstruction* instruction,
+      const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
+    const auto shape = instruction->shape();
+    const auto* lhs = instruction->operand(0);
+    const auto* rhs = instruction->operand(1);
+    const auto* ehs = instruction->operand(2);
+
+    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
+    // removed.
+    if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
+          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
+          ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
+      return Unimplemented(
+          "Implicit broadcasting is currently unsupported in HLO evaluator; "
+          "shape mismatch: %s vs %s vs %s vs %s",
+          ShapeUtil::HumanString(shape).c_str(),
+          ShapeUtil::HumanString(lhs->shape()).c_str(),
+          ShapeUtil::HumanString(rhs->shape()).c_str(),
+          ShapeUtil::HumanString(ehs->shape()).c_str());
+    }
+
+    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+    const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
+
+    auto result = LiteralUtil::CreateFromShape(shape);
+
+    TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
+        result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          return ternary_op(LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
+                            LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
+                            LiteralUtil::Get<EhsType>(ehs_literal,
+                                                      multi_index));
+        }));
+
+    return std::move(result);
+  }
+
+  HloEvaluator* parent_;
+};
+
+HloEvaluator::HloEvaluator() {
+  typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
+  typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
+  typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("unhandled primitive type: U16.");
+  });
+  typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
+  typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
+  typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
+  typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("unhandled primitive type: S16.");
+  });
+  typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
+  typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
+  typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("unhandled primitive type: F16.");
+  });
+  typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
+  typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
+  typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("unhandled primitive type: TUPLE.");
+  });
+  typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("unhandled primitive type: OPAQUE.");
+  });
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+    HloComputation* computation,
+    tensorflow::gtl::ArraySlice<const Literal*> args) {
+  arg_literals_ = args;
+  evaluated_.clear();
+
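+  // Accept() walks the computation in post order, so each instruction is
+  // visited only after all of its operands have been evaluated and cached in
+  // evaluated_.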
TF_RETURN_IF_ERROR(computation->Accept(this)); + return MakeUnique( + GetEvaluatedLiteralFor(computation->root_instruction())); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice operands) { + TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + arg_literals_ = operands; + evaluated_.clear(); + + // Evaluate operands of Parameter type against the input literals which + // caches the evaluated literal results. + for (const auto operand : instruction->operands()) { + if (operand->opcode() == HloOpcode::kParameter) { + const Literal* input_literal = arg_literals_[operand->parameter_number()]; + VLOG(2) << "Parameter operand evaluated to: " + << LiteralUtil::ToString(*input_literal); + TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); + + evaluated_[operand] = MakeUnique(*input_literal); + } + } + + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction) { + TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); + TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + arg_literals_.clear(); + evaluated_.clear(); + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +std::unique_ptr HloEvaluator::TryEvaluate( + HloInstruction* instruction) { + auto result_or = Evaluate(instruction); + if (!result_or.ok()) { + VLOG(1) << "TryEvaluate failed:" << result_or.status(); + return nullptr; + } + + return result_or.ConsumeValueOrDie(); +} + +Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + VLOG(2) << "HandleParameter: " << parameter->ToString(); + const Literal* input_literal = arg_literals_[parameter->parameter_number()]; + VLOG(2) << "Parameter evaluated to: " + << LiteralUtil::ToString(*input_literal); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + + evaluated_[parameter] = MakeUnique(*input_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleConstant(HloInstruction* constant, + const Literal& literal) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + return Status::OK(); +} + +Status HloEvaluator::HandleReshape(HloInstruction* reshape) { + TF_ASSIGN_OR_RETURN( + evaluated_[reshape], + LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), + AsInt64Slice(reshape->shape().dimensions()))); + return Status::OK(); +} + +Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { + evaluated_[transpose] = LiteralUtil::Transpose( + GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + return Status::OK(); +} + +Status HloEvaluator::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + // The result concatenate dimension is going to be the sum of all concatenate + // dimensions of the operands taking part of the operation. 
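+  // For example, concatenating f32[2,3] with f32[5,3] along dimension 0
+  // yields f32[7,3]; all other dimensions must match across the operands.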
+ const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + const int64 rank = ShapeUtil::Rank(reference_shape); + const int64 concat_dim = concatenate->dimensions()[0]; + CHECK_GE(concat_dim, 0); + CHECK_LT(concat_dim, rank); + + DimensionVector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + // Accumulate the concat dimension from all tensors taking part to the + // operation. + concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto result_literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + DimensionVector source_indices(rank, 0); + DimensionVector dest_indices(concat_dimensions.size(), 0); + + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), + dest_indices, AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + evaluated_[concatenate] = std::move(result_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + if (!ShapeUtil::ElementIsFloating(operand->shape())) { + return InvalidArgument( + "expected element type in shape to be float for IsFinite op, got: %s", + PrimitiveType_Name(operand->shape().element_type()).c_str()); + } + + switch (operand->shape().element_type()) { + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](float elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](double elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "unknown/unhandled primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s", + ShapeUtil::HumanString(compare->shape()).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); + + const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs); + + // Note here we switch on the operand's type. 
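+  // The typed visitors are keyed on the instruction's *result* element type,
+  // which for a comparison is always PRED; dispatching on it would lose the
+  // operand type needed to read the input literals, hence the switch below.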
+  switch (lhs->shape().element_type()) {
+    case PRED: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case U8: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case U16:
+      return Unimplemented("unhandled primitive type: U16.");
+    case U32: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case U64: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case S8: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case S16:
+      return Unimplemented("unhandled primitive type: S16.");
+    case S32: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case S64: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case F16:
+      return Unimplemented("unhandled primitive type: F16.");
+    case F32: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    case F64: {
+      TF_ASSIGN_OR_RETURN(
+          evaluated_[compare],
+          Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
+    } break;
+    default:
+      LOG(FATAL) << "unknown primitive type.";
+  }
+
+  return Status::OK();
+}
+
+Status HloEvaluator::HandleSlice(HloInstruction* slice,
+                                 HloInstruction* operand) {
+  const Shape& shape = slice->shape();
+  auto literal = LiteralUtil::CreateFromDimensions(
+      shape.element_type(), AsInt64Slice(shape.dimensions()));
+
+  DimensionVector dest_indices(slice->slice_starts().size(), 0);
+
+  TF_RETURN_IF_ERROR(LiteralUtil::Copy(
+      GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(),
+      dest_indices, AsInt64Slice(shape.dimensions())));
+
+  evaluated_[slice] = std::move(literal);
+  return Status::OK();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
new file mode 100644
index 00000000000..91fd56f54c5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -0,0 +1,158 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Evaluates HLO instructions and computations, producing literals as the
+// evaluation results.
+//
+// This class is not thread-safe.
+class HloEvaluator : public DfsHloVisitorWithDefault {
+ public:
+  HloEvaluator();
+  // Evaluates an HLO computation against an array of pointers to literals.
+  // Returns the evaluated result as a literal if successful.
+  // Precondition: the argument literals correspond to the input computation's
+  // parameters in their post-ordering. For example, consider the following
+  // graph:
+  //
+  //                *
+  //              /   \
+  //             +    Parameter1
+  //           /   \
+  //          /     \
+  //    Parameter0  Constant
+  //
+  // The input literals array will have its first literal map to Parameter0
+  // and the second map to Parameter1.
+  StatusOr<std::unique_ptr<Literal>> Evaluate(
+      HloComputation* computation,
+      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+
+  // Evaluates a single HLO instruction against an array of pointers to
+  // literals. Returns the evaluated result as a literal if successful.
+  // Preconditions:
+  // 1. argument literals correspond to the input instruction's parameters in
+  //    their post-ordering.
+  // 2. the instruction's operands must be of either Parameter or Constant
+  //    type.
+  // TODO(b/35950897): implement more ops other than element-wise ops.
+  StatusOr<std::unique_ptr<Literal>> Evaluate(
+      HloInstruction* instruction,
+      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+
+  // Evaluates a single HLO instruction with constant operands.
+  // Returns the evaluated result as a literal if successful.
+  // Preconditions:
+  // 1. all operands of the input instruction are constants.
+  // 2. the instruction is not a Parameter operation.
+  StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+
+  // Same as Evaluate, except it returns nullptr on error.
+  std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+
+ protected:
+  // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
+  // literal type of each evaluated Handle* method of a TypedVisitor.
+  // There are, however, a few notable exceptions to this rule:
+  // - HandleCompare and HandleIsFinite: the resulting literal type is always
+  //   boolean.
+  // These operations are handled by the parent HloEvaluator directly, rather
+  // than from within TypedVisitor.
+  template <typename ReturnT>
+  class TypedVisitor;
+
+  // Wraps around instruction handling to infer types before dispatching to
+  // the corresponding typed visitor.
+  Status DefaultAction(HloInstruction* hlo) override {
+    return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
+  }
+
+  // Operations that are type-agnostic.
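+  // These are handled directly by HloEvaluator rather than by the typed
+  // visitors, either because they only move or reshape data (no per-element
+  // arithmetic), or because, as with kCompare and kIsFinite, the result
+  // element type differs from the operand element type.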
+  //
+  Status HandleParameter(HloInstruction* parameter) override;
+
+  Status HandleConstant(HloInstruction* constant,
+                        const Literal& literal) override;
+
+  Status HandleConcatenate(
+      HloInstruction* concatenate,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+
+  Status HandleReshape(HloInstruction* reshape) override;
+
+  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+  Status HandleTranspose(HloInstruction* transpose) override;
+
+  Status HandleIsFinite(HloInstruction* is_finite,
+                        HloInstruction* operand) override;
+
+  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+                       HloInstruction* lhs, HloInstruction* rhs) override;
+
+ private:
+  // Returns the already-evaluated literal result for the instruction.
+  // A Constant instruction is considered evaluated and its literal will be
+  // returned directly without looking up the cache.
+  // Crashes (with a logged message) if the given instruction has not been
+  // evaluated previously.
+  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
+    if (hlo->IsConstant()) {
+      return hlo->literal();
+    }
+    auto it = evaluated_.find(hlo);
+    CHECK(it != evaluated_.end())
+        << "could not find evaluated value for: " << hlo->ToString();
+    return *(it->second);
+  }
+
+  // Map from a primitive type to its associated (templated) DfsHloVisitor.
+  // Note: the hash function here is only needed because current gcc std::hash
+  // does not specialize for enum types. This should however be fixed in the
+  // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5
+  tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>,
+                           std::hash<int>>
+      typed_visitors_;
+
+  // Tracks the HLO instruction and its evaluated literal result.
+  // TODO(b/35950897): have better memory management here to free evaluated
+  // results that no longer feed any other subsequent instruction in the
+  // post-ordering.
+  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
+      evaluated_;
+
+  // Stores input literals, assuming they are in post-order. Literals are not
+  // owned by this class, and they must outlive the lifetime of the instance
+  // of this class.
+  tensorflow::gtl::ArraySlice<const Literal*> arg_literals_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
+};
+
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
new file mode 100644
index 00000000000..b26ece28b75
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -0,0 +1,245 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloEvaluatorTest : public ::testing::Test { + protected: + HloEvaluatorTest() { evaluator_ = MakeUnique(); } + + std::unique_ptr evaluator_; +}; + +// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp +// with 3 operands. +TEST_F(HloEvaluatorTest, DoesClamp) { + auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + + Shape shape = low->shape(); + auto c1 = HloInstruction::CreateConstant(std::move(low)); + auto c2 = HloInstruction::CreateConstant(std::move(high)); + auto c3 = HloInstruction::CreateConstant(std::move(value)); + auto instruction = HloInstruction::CreateTernary( + shape, HloOpcode::kClamp, c1.get(), c2.get(), c3.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs select +// with 3 operands. +TEST_F(HloEvaluatorTest, DoesSelect) { + auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); + auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + + Shape shape = on_true->shape(); + auto c1 = HloInstruction::CreateConstant(std::move(pred)); + auto c2 = HloInstruction::CreateConstant(std::move(on_true)); + auto c3 = HloInstruction::CreateConstant(std::move(on_false)); + auto instruction = HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, c1.get(), c2.get(), c3.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise addition with 2 operands. 
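+// (As in the tests above, the argument list passed to Evaluate() is empty
+// because every operand is a constant; argument literals are only consulted
+// for Parameter operands.)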
+TEST_F(HloEvaluatorTest, DoesAdd) { + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1 = HloInstruction::CreateConstant(std::move(lhs)); + auto c2 = HloInstruction::CreateConstant(std::move(rhs)); + auto instruction = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1.get(), c2.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise divide with 2 operands. +TEST_F(HloEvaluatorTest, DoesDivide) { + auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + + Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1_s64 = HloInstruction::CreateConstant(std::move(lhs_s64)); + auto c2_s64 = HloInstruction::CreateConstant(std::move(rhs_s64)); + auto instruction = HloInstruction::CreateBinary(shape_s64, HloOpcode::kDivide, + c1_s64.get(), c2_s64.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + auto lhs_f64 = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + + Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); + auto c1_f64 = HloInstruction::CreateConstant(std::move(lhs_f64)); + auto c2_f64 = HloInstruction::CreateConstant(std::move(rhs_f64)); + instruction = HloInstruction::CreateBinary(shape_f64, HloOpcode::kDivide, + c1_f64.get(), c2_f64.get()); + + result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + expected = + LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise abs op with 1 operand. +TEST_F(HloEvaluatorTest, DoesAbs) { + auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1 = HloInstruction::CreateConstant(std::move(operand)); + auto instruction = + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + // For R0 literal. + const Shape& r0 = ShapeUtil::MakeShape(F32, {}); + operand = LiteralUtil::CreateR0(-1.0f); + c1 = HloInstruction::CreateConstant(std::move(operand)); + instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); + result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); + expected = LiteralUtil::CreateR0(1.0f); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + // For R1 literal with dimension of size 0. 
+  Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
+  operand = LiteralUtil::CreateR1<float>({});
+  c1 = HloInstruction::CreateConstant(std::move(operand));
+  instruction =
+      HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());
+
+  result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
+  expected = LiteralUtil::CreateR1<float>({});
+
+  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+}
+
+// Verifies that HloEvaluator evaluates an HLO computation whose operands are
+// neither parameters nor constants, i.e., it traverses the instructions.
+TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
+  HloComputation::Builder builder(
+      ::testing::UnitTest::GetInstance()->current_test_info()->name());
+
+  auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+  auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+  auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
+  std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+
+  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
+
+  auto param_lhs = HloInstruction::CreateParameter(0, shape, "lhs");
+  auto param_rhs = HloInstruction::CreateParameter(1, shape, "rhs");
+  auto lhs_instruction = HloInstruction::CreateBinary(
+      shape, HloOpcode::kAdd, param_lhs.get(), param_rhs.get());
+
+  auto param_rhs2 = HloInstruction::CreateParameter(2, shape, "rhs2");
+  auto root_instruction = HloInstruction::CreateBinary(
+      shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get());
+
+  builder.AddInstruction(std::move(root_instruction));
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie();
+
+  auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
+
+  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+}
+
+// Verifies that a Transpose of a constant is correctly evaluated.
+TEST_F(HloEvaluatorTest, DoesTranspose) {
+  HloComputation::Builder builder(
+      ::testing::UnitTest::GetInstance()->current_test_info()->name());
+
+  const int64 dimensions[] = {11, 8, 7, 5, 9};
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+  HloInstruction* literal_instruction = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+
+  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+  const int64 permutation[] = {1, 2, 0, 4, 3};
+  builder.AddInstruction(HloInstruction::CreateTranspose(
+      shape, literal_instruction, permutation));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+  LiteralUtil::EachCell<NativeT>(
+      *result,
+      [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+        std::vector<int64> rindexes = Permute(permutation, indices);
+        EXPECT_TRUE(value ==
+                    LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
+      });
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 0b87b04fc4b..9e25f1aceb1 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -32,6 +33,7 @@ namespace xla { void HloExecutionProfile::AddProfileResult(const HloInstruction* hlo, uint64 cycles_taken) { hlo_to_cycles_taken_[hlo] = cycles_taken; + profiled_computations_.insert(hlo->parent()); } uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const { @@ -43,63 +45,104 @@ uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const { } string HloExecutionProfile::ToString( + const HloComputation& computation, const DeviceDescription& device_description, - const HloCostAnalysis& cost_analysis) const { + const HloCostAnalysis::ShapeSizeFunction& shape_size) const { + HloCostAnalysis cost_analysis(shape_size); + tensorflow::Status analysis_status = + computation.root_instruction()->Accept(&cost_analysis); + if (!analysis_status.ok()) { + return ""; + } + using Item = std::pair; - std::vector items(hlo_to_cycles_taken_.begin(), - hlo_to_cycles_taken_.end()); + std::vector items; + for (Item item : hlo_to_cycles_taken_) { + // Only include the HLOs which are part of the desired computation. + if (item.first->parent() == &computation) { + items.push_back(item); + } + } auto custom_less = [](const Item& lhs, const Item& rhs) { return lhs.second > rhs.second; }; std::sort(items.begin(), items.end(), custom_less); string result; - const int64 total_cycles = total_cycles_executed(); + const int64 total_cycles = total_cycles_executed(computation); double clock_rate_ghz = device_description.clock_rate_ghz(); + CHECK_GE(clock_rate_ghz, 1e-9); const auto cycles_to_microseconds = [&](double cycles) { return cycles / clock_rate_ghz / 1000.0; }; - auto append_item = [&](int64 cycles, int64 flops, const string& name) { + auto append_item = [&](int64 cycles, int64 flops, int64 bytes_accessed, + const string& name) { double nsecs = cycles / clock_rate_ghz; + string bytes_per_sec; + string bytes_per_cycle; + if (cycles <= 0 || bytes_accessed < 0) { + bytes_per_sec = ""; + bytes_per_cycle = ""; + } else { + bytes_per_sec = tensorflow::strings::HumanReadableNumBytes( + bytes_accessed / (nsecs / 1e9)); + bytes_per_cycle = + tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles); + } + + double cycles_percent = 0; + if (total_cycles > 0) { + cycles_percent = cycles / static_cast(total_cycles) * 100; + } + tensorflow::strings::StrAppend( &result, tensorflow::strings::Printf( - "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %s", - cycles, cycles / static_cast(total_cycles) * 100, - cycles_to_microseconds(cycles), + "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %12s/s " + ":: " + "%12s/cycle :: " + "%s", + cycles, cycles_percent, cycles_to_microseconds(cycles), flops <= 0 ? 
"" : HumanReadableNumFlops(flops, nsecs).c_str(), - name.c_str())); + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str())); }; tensorflow::strings::StrAppend( - &result, - tensorflow::strings::Printf("HLO execution profile: (%s @ f_nom)\n\t", - tensorflow::strings::HumanReadableElapsedTime( - total_cycles / clock_rate_ghz / 1e9) - .c_str())); - append_item(total_cycles, -1, "[total]"); + &result, tensorflow::strings::Printf( + "HLO execution profile for %s: (%s @ f_nom)\n\t", + computation.name().c_str(), + tensorflow::strings::HumanReadableElapsedTime( + total_cycles / clock_rate_ghz / 1e9) + .c_str())); + + append_item(total_cycles, -1, -1, "[total]"); for (const auto& item : items) { + const HloInstruction* hlo = item.first; tensorflow::strings::StrAppend(&result, "\n\t"); - auto flops = item.first == nullptr - ? -1 - : cost_analysis.hlo_to_flop_count(*item.first); - string display = item.first == nullptr ? "" : item.first->ToString(); - append_item(item.second, flops, display); + const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo); + const int64 bytes_accessed = + (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo); + const string display = (hlo == nullptr) ? "" : hlo->ToString(); + append_item(item.second, flops, bytes_accessed, display); } - MetricTableReport table; - table.SetMetricName("microseconds"); - table.SetEntryName("ops"); - table.SetShowCategoryTable(); - for (const auto& item : items) { - MetricTableReport::Entry entry; - entry.text = item.first->ToString(); - entry.short_text = item.first->ToString(/*compact_operands=*/true); - entry.category_text = item.first->ToCategory(); - entry.metric = cycles_to_microseconds(item.second); - table.AddEntry(std::move(entry)); + if (total_cycles <= 0) { + result += "****** 0 total cycles ******\n"; + } else { + MetricTableReport table; + table.SetMetricName("microseconds"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& item : items) { + MetricTableReport::Entry entry; + entry.text = item.first->ToString(); + entry.short_text = item.first->ToString(/*compact_operands=*/true); + entry.category_text = item.first->ToCategory(); + entry.metric = cycles_to_microseconds(item.second); + table.AddEntry(std::move(entry)); + } + result += table.MakeReport(cycles_to_microseconds(total_cycles)); } - result += table.MakeReport(cycles_to_microseconds(total_cycles)); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 6cc20798139..70b94a3f950 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -43,27 +43,45 @@ class HloExecutionProfile { uint64 GetProfileResult(const HloInstruction& hlo) const; // Return the number of cycles this computation took to execute. - uint64 total_cycles_executed() const { return total_cycles_executed_; } + uint64 total_cycles_executed(const HloComputation& computation) const { + auto it = total_cycles_executed_.find(&computation); + if (it != total_cycles_executed_.end()) { + return it->second; + } + return 0; + } - // Record how many cycles the entire computation took to execute. - void set_total_cycles_executed(uint64 total_cycles_executed) { - total_cycles_executed_ = total_cycles_executed; + // Record how many cycles a computation took to execute. 
+ void set_total_cycles_executed(const HloComputation& computation, + uint64 total_cycles_executed) { + total_cycles_executed_[&computation] = total_cycles_executed; } // Returns a version of the execution profile suitable for performance // debugging; e.g. emits cycle counts, execution time at the nominal device // frequency, and the effective throughput given the provided cost_analysis - // for the operations. - string ToString(const DeviceDescription& device_description, - const HloCostAnalysis& cost_analysis) const; + // for the operations in a given computation. + // Returns an empty string if it wasn't possible to generate a printable + // version. + string ToString(const HloComputation& computation, + const DeviceDescription& device_description, + const HloCostAnalysis::ShapeSizeFunction& shape_size) const; + + // Returns the computations we have profiled. + std::unordered_set<const HloComputation*> profiled_computations() const { + return profiled_computations_; + } private: // Contains a mapping from HLO to the number of cycles it took to execute it. std::unordered_map<const HloInstruction*, uint64> hlo_to_cycles_taken_; - // If non-empty, contains the total number of cycles this computation took to + // If non-empty, contains the total number of cycles a computation took to // execute. - uint64 total_cycles_executed_ = 0; + std::unordered_map<const HloComputation*, uint64> total_cycles_executed_; + + // The computations we have profiled. + std::unordered_set<const HloComputation*> profiled_computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 631e784755d..eb2e5dfb37f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -15,14 +15,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include #include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -45,6 +48,73 @@ namespace xla { namespace hlo_graph_dumper { namespace { +// Node color schemes, used by NodeColorAttributes. +enum ColorScheme { + kBlue, + kBrown, + kDarkBlue, + kDarkGreen, + kDarkRed, + kGray, + kGreen, + kOrange, + kPurple, + kRed, + kWhite, + kYellow, +};
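The append_item helper earlier in this file's diff is plain unit conversion: with the clock rate in GHz, one cycle is 1/ghz nanoseconds, so microseconds = cycles / ghz / 1000, and throughput divides bytes accessed by elapsed seconds. A small sketch with made-up numbers (the 2 GHz clock and the counts are illustrative only):

```c++
#include <cstdio>

int main() {
  const double clock_rate_ghz = 2.0;         // cycles per nanosecond
  const long long cycles = 4000000;          // profiled cycle count
  const long long bytes_accessed = 1 << 26;  // 64 MiB

  const double nsecs = cycles / clock_rate_ghz;           // 2e6 ns
  const double usecs = cycles / clock_rate_ghz / 1000.0;  // 2000 us
  const double bytes_per_sec = bytes_accessed / (nsecs / 1e9);
  const double bytes_per_cycle =
      static_cast<double>(bytes_accessed) / cycles;

  std::printf("%.1f us, %.3g bytes/s, %.3g bytes/cycle\n", usecs,
              bytes_per_sec, bytes_per_cycle);
  return 0;
}
```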
+ +// Given a ColorScheme, returns an attribute string for a node of that color. +// Sets the node's fill, stroke, and text colors. +// +// Colors are from https://material.io/color. +string NodeColorAttributes(ColorScheme color) { + using std::make_tuple; + + const char *fill_color, *stroke_color, *font_color; + std::tie(fill_color, stroke_color, font_color) = + [color]() -> std::tuple<const char*, const char*, const char*> { + switch (color) { + case kBlue: + return make_tuple("#bbdefb", "#8aacc8", "black"); + case kBrown: + return make_tuple("#bcaaa4", "#8c7b75", "black"); + case kDarkBlue: + return make_tuple("#1565c0", "#003c8f", "white"); + case kDarkGreen: + return make_tuple("#2e7d32", "#005005", "white"); + case kDarkRed: + return make_tuple("#b71c1c", "#7f0000", "white"); + case kGray: + return make_tuple("#cfd8dc", "#9ea7aa", "black"); + case kGreen: + return make_tuple("#c8e6c9", "#97b498", "black"); + case kOrange: + return make_tuple("#ffe0b2", "#cbae82", "black"); + case kPurple: + return make_tuple("#e1bee7", "#af8eb5", "black"); + case kRed: + return make_tuple("#ffcdd2", "#cb9ca1", "black"); + case kWhite: + return make_tuple("white", "black", "black"); + case kYellow: + return make_tuple("#fff9c4", "#cbc693", "black"); + } + }(); + + return Printf( + "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"", + font_color, stroke_color, fill_color); +} + +// Replaces < and > with &lt; and &gt;, so that this string is safe(r) for use +// in a graphviz HTML-like string. +string HtmlLikeStringSanitize(tensorflow::StringPiece s) { + return tensorflow::str_util::StringReplace( + tensorflow::str_util::StringReplace(s, "<", "&lt;", /*replace_all=*/true), + ">", "&gt;", /*replace_all=*/true); +} + // Returns the dot graph identifier for the given instruction. string InstructionId(const HloInstruction* instruction) { return Printf("%lld", reinterpret_cast<uint64>(instruction)); } @@ -55,68 +125,6 @@ string ComputationId(const HloComputation* computation) { return Printf("%lld", reinterpret_cast<uint64>(computation)); } -// Returns a compact string that represents the convolution dimension numbers. -string ConvolutionDimensionNumbersToString( - const ConvolutionDimensionNumbers& dim_numbers) { - return Printf("B@%lld,Z@%lld,KIZ@%lld,KOZ@%lld", - dim_numbers.batch_dimension(), dim_numbers.feature_dimension(), - dim_numbers.kernel_input_feature_dimension(), - dim_numbers.kernel_output_feature_dimension()); -} - -// Returns a compact string that represents the non-trivial fields in the window -// description. If there are no non-trivial fields, the empty string is -// returned. -string WindowToString(const Window& window) { - bool display_padding = false; - bool display_window_dilation = false; - bool display_base_dilation = false; - bool display_stride = false; - for (const WindowDimension& dimension : window.dimensions()) { - display_padding |= - dimension.padding_low() != 0 || dimension.padding_high() != 0; - display_window_dilation |= dimension.window_dilation() != 1; - display_base_dilation |= dimension.base_dilation() != 1; - display_stride |= dimension.stride() != 1; - } - std::vector<string> pieces = {}; - if (display_padding) { - pieces.push_back("\\n"); - pieces.push_back("padding=["); - for (const WindowDimension& dimension : window.dimensions()) { - pieces.push_back(StrCat("(", dimension.padding_low(), ",", - dimension.padding_high(), ")")); - pieces.push_back(", "); - } - pieces.pop_back(); - pieces.push_back("]"); - } - // Make a convenient lambda that adds a simple int64 field in each - // WindowDimension.
- auto add_field = [&pieces, &window]( - const string& label, - tensorflow::protobuf_int64 (WindowDimension::*member)() const) { - pieces.push_back("\\n"); - pieces.push_back(label + "=["); - for (const WindowDimension& dimension : window.dimensions()) { - pieces.push_back(StrCat(((&dimension)->*member)())); - pieces.push_back(", "); - } - pieces.pop_back(); - pieces.push_back("]"); - }; - if (display_window_dilation) { - add_field("window_dilation", &WindowDimension::window_dilation); - } - if (display_base_dilation) { - add_field("base_dilation", &WindowDimension::base_dilation); - } - if (display_stride) { - add_field("stride", &WindowDimension::stride); - } - return Join(pieces, ""); -} - // Returns the dot graph edges and nodes for the given instruction sequence. // Edges which extend between computations are added to the vector // intercomputation_edges. This is necessary because graphviz does not render @@ -135,7 +143,8 @@ string InstructionSequenceGraph( std::vector param_instructions; for (auto& instruction : instructions) { if (instruction->opcode() == HloOpcode::kParameter) { - int64 param_number = instruction->parameter_number(); + size_t param_number = instruction->parameter_number(); + if (param_instructions.size() < param_number + 1) { param_instructions.resize(param_number + 1, nullptr); } @@ -160,25 +169,38 @@ string InstructionSequenceGraph( param_ports.push_back( Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); } - StrAppend(&graph_body, param_node_name, - " [shape=record,style=filled,fillcolor=\"lightblue1\",", - "label=\"{parameters | {", Join(param_ports, "|"), "}}\"];\n"); + // (If we wanted the word "parameters" to be bold like the other op names, + // we'd have to make this into an HTML-like table. It is possible but + // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.) + StrAppend(&graph_body, param_node_name, " [shape=record ", + NodeColorAttributes(kOrange), "label=\"{parameters | {", + Join(param_ports, "|"), "}}\"];\n"); } for (auto& instruction : instructions) { - string color = "peachpuff"; - string shape = "ellipse"; - string name = HloOpcodeString(instruction->opcode()); - if (HloOpcode::kFusion == instruction->opcode()) { - name += ": " + FusionKindString(instruction->fusion_kind()); - } + ColorScheme color = kYellow; + string shape = "box"; + string name = + StrCat("", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), + " ", HtmlLikeStringSanitize(instruction->name())); if (HloOpcode::kConvolution == instruction->opcode()) { - name += ":\\n" + ConvolutionDimensionNumbersToString( - instruction->convolution_dimension_numbers()) + - WindowToString(instruction->window()); + StrAppend( + &name, "
", + HtmlLikeStringSanitize( + instruction->ConvolutionDimensionNumbersToString()), + "
", + HtmlLikeStringSanitize(window_util::ToString(instruction->window()))); + } + + if (!instruction->metadata().op_name().empty()) { + StrAppend(&name, "
", + HtmlLikeStringSanitize(instruction->metadata().op_name())); + } + if (!instruction->metadata().source_file().empty() && + instruction->metadata().source_line() != 0) { + StrAppend(&name, "
", instruction->metadata().source_file(), ":", + instruction->metadata().source_line()); } - name += "\\n" + instruction->name(); - std::vector called_computations; // Pick different colors or shapes for instructions which are particularly // expensive (eg, dot) and those which are unusual in some way or unique @@ -191,17 +213,15 @@ string InstructionSequenceGraph( case HloOpcode::kAdd: case HloOpcode::kCeil: case HloOpcode::kClamp: - case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: case HloOpcode::kIndex: + case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLogicalAnd: @@ -213,64 +233,49 @@ string InstructionSequenceGraph( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: - case HloOpcode::kPad: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kReshape: - case HloOpcode::kReverse: case HloOpcode::kSelect: case HloOpcode::kSign: case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kTuple: - case HloOpcode::kUpdate: - break; - - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - StrAppend(&name, "\\n", "dims={", Join(instruction->dimensions(), ","), - "}"); - break; - case HloOpcode::kGetTupleElement: - StrAppend(&name, "\\nindex=", instruction->tuple_index()); break; case HloOpcode::kRng: - StrAppend(&name, "\\n", + StrAppend(&name, "
", RandomDistribution_Name(instruction->random_distribution())); break; - case HloOpcode::kConstant: - shape = "boxed"; - color = "palegreen"; - if (ShapeUtil::IsScalar(instruction->shape())) { - StrAppend(&name, "\\n", "value=", LiteralUtil::GetAsString( - instruction->literal(), {})); - } + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + StrAppend(&name, "
", "dims={", + Join(instruction->dimensions(), ","), "}"); break; case HloOpcode::kBitcast: + case HloOpcode::kTuple: + case HloOpcode::kTrace: + color = kWhite; + break; + case HloOpcode::kGetTupleElement: + color = kWhite; + StrAppend(&name, "
index=", instruction->tuple_index()); + break; + case HloOpcode::kConcatenate: case HloOpcode::kCopy: - color = "white"; + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kUpdate: + color = kGreen; break; - case HloOpcode::kCall: - color = "tomato"; - break; - case HloOpcode::kCustomCall: - color = "tomato4"; - StrAppend(&name, "\\n", - "custom_call_target=", instruction->custom_call_target()); + case HloOpcode::kConstant: + color = kBlue; break; + case HloOpcode::kConvolution: case HloOpcode::kDot: - color = "slateblue"; - break; - case HloOpcode::kSend: - color = "purple"; - break; - case HloOpcode::kRecv: - color = "orange"; - break; - case HloOpcode::kMap: - color = "palevioletred"; + color = kDarkBlue; break; case HloOpcode::kParameter: // A single record node is created for all the parameter nodes with a @@ -279,38 +284,54 @@ string InstructionSequenceGraph( continue; case HloOpcode::kReduce: StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); - color = "lightsalmon"; + color = kPurple; break; case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: - color = "lightsalmon"; - break; - case HloOpcode::kTrace: - color = "white"; + color = kPurple; break; case HloOpcode::kWhile: - color = "forestgreen"; + shape = "ellipse"; + color = kDarkGreen; break; + case HloOpcode::kMap: case HloOpcode::kFusion: - color = "gray"; - break; - case HloOpcode::kConvolution: - color = "red"; - break; - case HloOpcode::kCrossReplicaSum: - color = "turquoise"; + color = kGray; break; + case HloOpcode::kSend: + case HloOpcode::kRecv: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - color = "blue"; + case HloOpcode::kCrossReplicaSum: + color = kBrown; + break; + case HloOpcode::kCall: + color = kDarkGreen; + break; + case HloOpcode::kCustomCall: + color = kDarkGreen; + StrAppend(&name, "
", + "custom_call_target=", instruction->custom_call_target()); break; } // Create instruction node with appropriate label, shape, and color. + // label is interpreted as an HTML-like string, so newlines must be + // delimited with
, rather than \n. string label = - StrCat(name, "\\n", ShapeUtil::HumanString(instruction->shape())); + StrCat(name, "<br/>
", ShapeUtil::HumanString(instruction->shape())); + + if (instruction->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instruction->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + instruction->shape(), /*linear_index=*/0); + StrAppend(&label, " = {", + LiteralUtil::GetAsString(instruction->literal(), elem_idx), + "}"); + } + if (show_addresses) { - Appendf(&label, "\\n[%p]", instruction.get()); + Appendf(&label, "
[%p]", instruction.get()); } if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { string layout_string; @@ -322,24 +343,24 @@ string InstructionSequenceGraph( layout_string = Join(instruction->shape().layout().minor_to_major(), ","); } - StrAppend(&label, "\\nlayout={", layout_string, "}"); + StrAppend(&label, "
layout={", layout_string, "}"); } if (hlo_execution_profile != nullptr) { auto hlo_cycles_executed = hlo_execution_profile->GetProfileResult(*instruction); auto total_cycles_executed = - hlo_execution_profile->total_cycles_executed(); + hlo_execution_profile->total_cycles_executed(*instruction->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { - Appendf(&label, "\\n%% of cycles executed=%.2f", + Appendf(&label, "
%% of cycles executed=%.2f", (static_cast(hlo_cycles_executed) / static_cast(total_cycles_executed)) * 100); } } - Appendf(&graph_body, - "%s [label=\"%s\", shape=%s, style=filled, fillcolor=%s];\n", + + Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), color.c_str()); + shape.c_str(), NodeColorAttributes(color).c_str()); // Create edges from the instruction's operands to the instruction. int64 operand_number = 0; @@ -369,7 +390,7 @@ string InstructionSequenceGraph( StrCat("cluster_", InstructionId(instruction.get())); StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); StrAppend(&graph_body, - "label=\"fused expression\";\nstyle=filled;\n" + "label=<fused expression>;\nstyle=\"rounded,filled\";\n" "color=lightgrey;\n"); StrAppend(&graph_body, InstructionSequenceGraph( instruction->fused_instructions(), @@ -385,7 +406,8 @@ string InstructionSequenceGraph( } else { // Add a dotted edge between the instruction and any computations that the // instruction calls. - for (auto* computation : instruction->MakeCalledComputationsSet()) { + for (const HloComputation* computation : + instruction->called_computations()) { string cluster_name = StrCat("cluster_", ComputationId(computation)); string call_edge = Printf( "%s -> %s [ style=dashed; ltail=%s ];\n", @@ -398,19 +420,39 @@ string InstructionSequenceGraph( return graph_body; } +// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet +// is a data URI! +// +// We don't perform any escaping on this string, so be careful not to use double +// quotes inside. +static const char* dot_stylesheet = R"( +data:text/css, +@import url(https://fonts.googleapis.com/css?family=Roboto:400,700); +svg text { + font-family: 'Roboto'; + font-size: 12px; +} +)"; + string ComputationToDotGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph_label = StrCat(label, "\\n", computation.name()); + string graph_label = StrCat(label, "
", computation.name()); if (hlo_execution_profile != nullptr) { - auto cycles = hlo_execution_profile->total_cycles_executed(); - Appendf(&graph_label, "\\ntotal cycles = %lld (%s)", cycles, + auto cycles = hlo_execution_profile->total_cycles_executed(computation); + Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, tensorflow::strings::HumanReadableNum(cycles).c_str()); } - string graph = - Printf("digraph G {\nrankdir=TB;\ncompound=true;\nlabel=\"%s\"\n", - graph_label.c_str()); + string graph = Printf( + R"(digraph G { +rankdir=TB; +compound=true; +label=<%s>; +labelloc=t; +stylesheet="%s" +)", + graph_label.c_str(), dot_stylesheet); // Emit embedded computations as subgraph clusters. std::vector intercomputation_edges; @@ -418,7 +460,9 @@ string ComputationToDotGraph(const HloComputation& computation, string graph_body = InstructionSequenceGraph( embedded->instructions(), show_addresses, show_layouts, &intercomputation_edges, hlo_execution_profile); - Appendf(&graph, "subgraph cluster_%s {\nlabel=\"%s\";\n%s}\n", + Appendf(&graph, + "subgraph cluster_%s " + "{\nstyle=rounded;label=<%s>;labelloc=t;\n%s}\n", ComputationId(embedded).c_str(), embedded->name().c_str(), graph_body.c_str()); } @@ -464,14 +508,34 @@ namespace { class FileGraphRenderer : public GraphRendererInterface { public: - string RenderGraph(const string& graph) override { + string RenderGraph(const string& graph, GraphKind graph_kind) override { static std::atomic output_num(0); legacy_flags::HloGraphDumperFlags* flags = legacy_flags::GetHloGraphDumperFlags(); - string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_", - output_num++, ".dot"); - tensorflow::Status status = - tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); + string file_extension; + switch (graph_kind) { + case DOT_GRAPH: + file_extension = ".dot"; + break; + case TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = + JoinPath(flags->xla_hlo_dump_graph_path, + StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + auto status = Status::OK(); + int fd = mkstemps(&path[0], file_extension.length()); + if (fd < 0) { + status = + Status(tensorflow::error::Code::UNKNOWN, + StrCat("Failed to create temporary file to dump HLO graph: ", + strerror(errno))); + } else { + status = tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, + graph); + close(fd); + } if (!status.ok()) { LOG(WARNING) << "Saving HLO graph failed: " << status; } @@ -486,10 +550,26 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph = ComputationToDotGraph(computation, label, show_addresses, - show_layouts, hlo_execution_profile); - - string graph_url = GetGraphRenderer()->RenderGraph(graph); + string graph; + string graph_url; + legacy_flags::HloGraphDumperFlags* flags = + legacy_flags::GetHloGraphDumperFlags(); + if (flags->xla_hlo_dump_as_graphdef) { + HloTfGraphBuilder builder; + TF_CHECK_OK(builder.AddComputation(computation)); + CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), + &graph)); + // TODO(b/37198616): Use the default registered renderers when all + // renderers support rendering GraphDefs. Always dump GraphDefs to files + // for now. 
+ graph_url = FileGraphRenderer().RenderGraph( + graph, GraphRendererInterface::TF_GRAPHDEF); + } else { + graph = ComputationToDotGraph(computation, label, show_addresses, + show_layouts, hlo_execution_profile); + graph_url = GetGraphRenderer()->RenderGraph( + graph, GraphRendererInterface::DOT_GRAPH); + } LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 5f841da1f35..8ed50c38473 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -25,8 +25,25 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { -// Dumps a graph of the computation to the GraphViz server and returns -// a description of the rendered graph (e.g., a URL). +// Abstract interface for classes that render HLO graphs (e.g. DOT graph, +// tensorflow GraphDef). +class GraphRendererInterface { + public: + enum GraphKind { + DOT_GRAPH, + TF_GRAPHDEF, + }; + + virtual ~GraphRendererInterface() = default; + + // Renders a DOT graph, returning a description of the rendered output + // (e.g., a URL) + virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; +}; + +// Dumps a graph of the computation and returns a description of the rendered +// graph (e.g., a URL) based on the renderer. The "best" renderer in the +// registry is used. string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile = nullptr); @@ -40,16 +57,6 @@ string DumpGraph(const HloComputation& computation, const string& label, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); -// Abstract interface for classes that render DOT graphs. -class GraphRendererInterface { - public: - virtual ~GraphRendererInterface() = default; - - // Renders a DOT graph, returning a description of the rendered output - // (e.g., a URL) - virtual string RenderGraph(const string& graph) = 0; -}; - // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index cd67757bb2c..ea813c98743 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -17,17 +17,19 @@ limitations under the License. #include #include +#include #include #include #include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -42,6 +44,11 @@ limitations under the License. 
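The XLA_REGISTER_GRAPH_RENDERER mechanism referenced above picks the registered renderer with the highest numeric priority. A simplified sketch of that pattern (all names here are hypothetical, not the actual XLA registry):

```c++
#include <map>
#include <memory>
#include <string>

class Renderer {
 public:
  virtual ~Renderer() = default;
  virtual std::string Render(const std::string& graph) = 0;
};

// Keyed by priority; ownership lives in the registry.
std::map<int, std::unique_ptr<Renderer>>& Registry() {
  static auto* registry = new std::map<int, std::unique_ptr<Renderer>>();
  return *registry;
}

bool RegisterRenderer(int priority, std::unique_ptr<Renderer> renderer) {
  Registry()[priority] = std::move(renderer);
  return true;  // the bool return lets registration run at static-init time
}

Renderer* BestRenderer() {
  auto& registry = Registry();
  // std::map iterates keys in ascending order, so rbegin() holds the
  // highest-priority renderer.
  return registry.empty() ? nullptr : registry.rbegin()->second.get();
}
```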
namespace xla { +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -58,7 +65,7 @@ namespace xla { WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_.reset(new Literal); - *instruction->literal_->mutable_u8s() += tag; + instruction->literal_->append_u8s(tag); return instruction; } @@ -117,6 +124,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: @@ -194,7 +202,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = map_computation; + instruction->called_computations_.push_back(map_computation); return instruction; } @@ -205,10 +213,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); if (window_util::HasBaseDilation(window)) { - instruction->set_name(instruction->name() + "-base-dilated"); + instruction->name_ = instruction->name() + "-base-dilated"; } if (window_util::HasWindowDilation(window)) { - instruction->set_name(instruction->name() + "-window-dilated"); + instruction->name_ = instruction->name() + "-window-dilated"; } instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); @@ -235,11 +243,13 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( - HloInstruction* operand, tensorflow::StringPiece outfeed_config) { + const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) { std::unique_ptr instruction = WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); instruction->AppendOperand(operand); instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_shape_ = shape; return instruction; } @@ -273,19 +283,22 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* init) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); - instruction->condition_ = condition; - instruction->body_ = body; + // Body comes before condition computation in the vector. 
+ instruction->called_computations_.push_back(body); + instruction->called_computations_.push_back(condition); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); instruction->AppendOperand(operand); instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); + instruction->slice_strides_.assign(strides.begin(), strides.end()); return instruction; } @@ -342,7 +355,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, instruction->AppendOperand(init_value); instruction->dimensions_.assign(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); return instruction; } @@ -353,7 +366,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(init_value); - instruction->to_apply_ = reduce_computation; + instruction->called_computations_.push_back(reduce_computation); instruction->window_ = MakeUnique(window); return instruction; } @@ -368,8 +381,9 @@ HloInstruction::CreateSelectAndScatter( instruction->AppendOperand(operand); instruction->AppendOperand(source); instruction->AppendOperand(init_value); - instruction->select_ = select; - instruction->scatter_ = scatter; + // Select comes before scatter in the vector. + instruction->called_computations_.push_back(select); + instruction->called_computations_.push_back(scatter); instruction->window_ = MakeUnique(window); return instruction; } @@ -398,7 +412,9 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand) { CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(operand->shape())); + ShapeUtil::ElementsIn(operand->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operand->shape()); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; @@ -423,6 +439,8 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; + instruction->set_parent(fused_root->parent()); + instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); instruction->CheckFusionInstruction(); return instruction; @@ -477,7 +495,7 @@ HloInstruction* HloInstruction::FuseInstruction( CHECK_EQ(opcode_, HloOpcode::kFusion); // This fusion instruction must be a user of instruction_to_fuse. 
- CHECK_NE(0, instruction_to_fuse->users().count(this)); + CHECK(IsUserOf(instruction_to_fuse)); HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse); CheckFusionInstruction(); return fused_instruction; @@ -488,14 +506,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(instruction_to_fuse->IsFusable()); - bool new_fusion_instruction = fused_instructions_.empty(); - fused_instructions_.emplace_back(instruction_to_fuse->Clone()); - HloInstruction* clone = fused_instructions_.back().get(); - clone->parent_fusion_instruction_ = this; - - if (new_fusion_instruction) { - fused_root_ = clone; + HloInstruction* clone = nullptr; + if (fused_instructions_computation_ == nullptr) { + // New fusion instruction. + auto builder = HloComputation::Builder("fused_computation", true); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + fused_instructions_computation_ = builder.Build(); + clone = fused_expression_root(); + clone->parent_fusion_instruction_ = this; } else { + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + clone = fused_instructions_computation_->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + clone->parent_fusion_instruction_ = this; // instruction_to_fuse is necessarily an operand of the fusion instruction. // After fusion this will no longer be the case. Remove the operand from the // operand list and remove its corresponding fused parameter @@ -503,6 +527,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // consistent with their index in the fused_parameter_ vector. CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) != operands_.end()); + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { if (instruction_to_fuse == operands_[operand_num]) { // replace the fused parameter instruction's uses with the clone. @@ -511,22 +537,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Remove the corresponding fused parameter and operand from their // respective vectors. - fused_parameters_.erase(fused_parameters_.begin() + operand_num); + TF_CHECK_OK( + fused_instructions_computation_->RemoveParameter(operand_num)); operands_.erase(operands_.begin() + operand_num); - - // Renumber fused parameter numbers to match the vector index. - while (operand_num < fused_parameters_.size()) { - fused_parameters_[operand_num]->parameter_number_ = operand_num; - operand_num++; - } - // Throw removed fused parameter instruction away. - auto inst_it = - std::find_if(fused_instructions_.begin(), fused_instructions_.end(), - [=](const std::unique_ptr& inst) { - return inst.get() == fused_parameter; - }); - CHECK(inst_it != fused_instructions_.end()); - fused_instructions_.erase(inst_it); break; } } @@ -535,6 +548,10 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( instruction_to_fuse->RemoveUser(this); } + // Reread the parameters in the computation. + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); + // Add each operand of the clone as an operand of the fusion instruction. A // complication is that some clone operands may already be operands of the // fusion instruction. @@ -557,19 +574,30 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // instruction. 
Add it as an operand and add a corresponding fused // parameter instruction. int64 param_no = fused_parameters_.size(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. Strip the leading "%" from the operand name + // to avoid a double %%. + string param_name = + StrCat(operand->name().substr(1), ".param_", param_no); std::unique_ptr param_instruction = - CreateParameter(param_no, operand->shape(), "fusion_param"); + CreateParameter(param_no, operand->shape(), param_name); param_instruction->parent_fusion_instruction_ = this; - fused_parameters_.push_back(param_instruction.get()); - fused_instructions_.push_back(std::move(param_instruction)); + fused_param = fused_instructions_computation_->AddParameter( + std::move(param_instruction)); AppendOperand(operand); - - fused_param = fused_instructions_.back().get(); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } + for (HloComputation* computation : + instruction_to_fuse->called_computations()) { + if (std::find(called_computations_.begin(), called_computations_.end(), + computation) == called_computations_.end()) { + called_computations_.push_back(computation); + } + } + return clone; } @@ -578,58 +606,27 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -namespace { - -// Adds any HloComputations this instruction calls directly to the given set. -void CalledComputationsInternal( - const HloInstruction& instruction, - std::set* called_computations) { - switch (instruction.opcode()) { - case HloOpcode::kCall: - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - called_computations->insert(instruction.to_apply()); - break; - case HloOpcode::kSelectAndScatter: - called_computations->insert(instruction.select()); - called_computations->insert(instruction.scatter()); - break; - case HloOpcode::kWhile: - called_computations->insert(instruction.while_condition()); - called_computations->insert(instruction.while_body()); - break; - case HloOpcode::kFusion: - for (const auto& fused_instruction : instruction.fused_instructions()) { - CalledComputationsInternal(*fused_instruction, called_computations); - } - break; - default: - break; - } -} - -} // namespace - -std::set HloInstruction::MakeCalledComputationsSet() const { - std::set called_computations; - CalledComputationsInternal(*this, &called_computations); - return called_computations; -} - void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + const std::list>& fused_instructions_ = + fused_instructions_computation_->instructions(); // All instructions owned by this fusion instruction must be fused, and the // parent fusion instruction of the fused instructions must be 'this'. for (auto& instruction : fused_instructions_) { CHECK(instruction->IsFused()); CHECK_EQ(this, instruction->fusion_instruction()); + CHECK_EQ(fused_instructions_computation_.get(), instruction->parent()) + << instruction->ToString(); } // Fused root instruction and fused parameters must all be owned by the fusion // instruction. 
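CloneAndFuseInternal above merges the fused instruction's called computations into called_computations_ with a find-then-push_back, keeping a vector rather than a set so iteration order stays deterministic. The idiom in isolation, as a generic helper sketch:

```c++
#include <algorithm>
#include <vector>

// Appends value only if it is not already present, preserving insertion
// order; a linear search is fine for the short lists involved here.
template <typename T>
void AppendUnique(std::vector<T>& vec, const T& value) {
  if (std::find(vec.begin(), vec.end(), value) == vec.end()) {
    vec.push_back(value);
  }
}
```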
bool root_owned = false; + const std::vector& fused_parameters_ = fused_parameters(); + const HloInstruction* fused_root_ = fused_expression_root(); std::vector parameter_owned(fused_parameters_.size(), false); for (auto& instruction : fused_instructions_) { if (fused_root_ == instruction.get()) { @@ -695,7 +692,7 @@ void HloInstruction::CheckFusionInstruction() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->to_apply_ = computation; + instruction->called_computations_.push_back(computation); return instruction; } @@ -722,7 +719,8 @@ void HloInstruction::CheckFusionInstruction() const { } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands) { // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -733,6 +731,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCeil: case HloOpcode::kCopy: case HloOpcode::kExp: + case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: case HloOpcode::kLogicalNot: @@ -740,8 +739,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSign: case HloOpcode::kSort: case HloOpcode::kTanh: - CHECK_EQ(operands.size(), 1); - return CreateUnary(shape, opcode_, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateUnary(shape, opcode_, new_operands[0]); // Binary ops. case HloOpcode::kAdd: case HloOpcode::kDivide: @@ -760,93 +759,93 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kLogicalAnd: case HloOpcode::kLogicalOr: - CHECK_EQ(operands.size(), 2); - return CreateBinary(shape, opcode_, operands[0], operands[1]); + CHECK_EQ(new_operands.size(), 2); + return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: - CHECK_EQ(operands.size(), 3); - return CreateTernary(shape, opcode_, operands[0], operands[1], - operands[2]); + CHECK_EQ(new_operands.size(), 3); + return CreateTernary(shape, opcode_, new_operands[0], new_operands[1], + new_operands[2]); // Other supported ops. 
case HloOpcode::kBroadcast: - CHECK_EQ(operands.size(), 1); - return CreateBroadcast(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateBroadcast(shape, new_operands[0], dimensions_); case HloOpcode::kCall: - return CreateCall(shape, operands, to_apply_); + return CreateCall(shape, new_operands, to_apply()); case HloOpcode::kCustomCall: - return CreateCustomCall(shape, operands, custom_call_target_); + return CreateCustomCall(shape, new_operands, custom_call_target_); case HloOpcode::kConcatenate: - return CreateConcatenate(shape, operands, dimensions(0)); + return CreateConcatenate(shape, new_operands, dimensions(0)); case HloOpcode::kConvert: - CHECK_EQ(operands.size(), 1); - return CreateConvert(shape, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateConvert(shape, new_operands[0]); case HloOpcode::kConvolution: - CHECK_EQ(operands.size(), 2); - return CreateConvolve(shape, operands[0], operands[1], *window_, + CHECK_EQ(new_operands.size(), 2); + return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); case HloOpcode::kCrossReplicaSum: - CHECK_EQ(operands.size(), 1); - return CreateCrossReplicaSum(shape, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateCrossReplicaSum(shape, new_operands[0]); case HloOpcode::kGetTupleElement: - CHECK_EQ(operands.size(), 1); - return CreateGetTupleElement(shape, operands[0], tuple_index()); + CHECK_EQ(new_operands.size(), 1); + return CreateGetTupleElement(shape, new_operands[0], tuple_index()); case HloOpcode::kMap: - return CreateMap(shape, operands, to_apply_); + return CreateMap(shape, new_operands, to_apply()); case HloOpcode::kPad: - CHECK_EQ(operands.size(), 2); - return CreatePad(shape, operands[0], operands[1], *padding_config_); + CHECK_EQ(new_operands.size(), 2); + return CreatePad(shape, new_operands[0], new_operands[1], + *padding_config_); case HloOpcode::kReduce: - CHECK_EQ(operands.size(), 2); - return CreateReduce(shape, operands[0], operands[1], dimensions_, - to_apply_); + CHECK_EQ(new_operands.size(), 2); + return CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, + to_apply()); case HloOpcode::kReduceWindow: - CHECK_EQ(operands.size(), 2); - return CreateReduceWindow(shape, operands[0], operands[1], *window_, - to_apply_); + CHECK_EQ(new_operands.size(), 2); + return CreateReduceWindow(shape, new_operands[0], new_operands[1], + *window_, to_apply()); case HloOpcode::kSelectAndScatter: - CHECK_EQ(operands.size(), 3); - return CreateSelectAndScatter(shape, operands[0], select_, *window_, - operands[1], operands[2], scatter_); - case HloOpcode::kRecv: - CHECK_EQ(operands.size(), 0); - return CreateRecv(shape, channel_id_); + CHECK_EQ(new_operands.size(), 3); + return CreateSelectAndScatter(shape, new_operands[0], select(), *window_, + new_operands[1], new_operands[2], + scatter()); case HloOpcode::kReverse: - CHECK_EQ(operands.size(), 1); - return CreateReverse(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateReverse(shape, new_operands[0], dimensions_); case HloOpcode::kRng: - return CreateRng(shape, distribution_, operands); + return CreateRng(shape, distribution_, new_operands); case HloOpcode::kReshape: - CHECK_EQ(operands.size(), 1); - return CreateReshape(shape, operands[0]); - case HloOpcode::kSend: - CHECK_EQ(operands.size(), 1); - return CreateSend(operands[0], channel_id_); + CHECK_EQ(new_operands.size(), 1); + return CreateReshape(shape, new_operands[0]); case 
HloOpcode::kSlice: - CHECK_EQ(operands.size(), 1); - return CreateSlice(shape, operands[0], slice_starts_, slice_limits_); + CHECK_EQ(new_operands.size(), 1); + return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, + slice_strides_); case HloOpcode::kDynamicSlice: - return CreateDynamicSlice(shape, operands[0], operands[1], + return CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(operands.size(), 3); - return CreateDynamicUpdateSlice(shape, operands[0], operands[1], - operands[2]); + CHECK_EQ(new_operands.size(), 3); + return CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], + new_operands[2]); case HloOpcode::kTranspose: - CHECK_EQ(operands.size(), 1); - return CreateTranspose(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateTranspose(shape, new_operands[0], dimensions_); case HloOpcode::kTuple: - return CreateTuple(operands_); + return CreateTuple(new_operands); case HloOpcode::kWhile: - CHECK_EQ(operands.size(), 1); - return CreateWhile(shape, condition_, body_, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateWhile(shape, while_condition(), while_body(), + new_operands[0]); case HloOpcode::kConstant: return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); case HloOpcode::kFusion: - return CloneFusionWithNewOperands(shape, operands); + return CloneFusionWithNewOperands(shape, new_operands); case HloOpcode::kParameter: return CreateParameter(parameter_number_, shape, parameter_name_); // Unsupported ops for cloning. + case HloOpcode::kRecv: + case HloOpcode::kSend: case HloOpcode::kUpdate: case HloOpcode::kIndex: case HloOpcode::kInfeed: @@ -856,16 +855,55 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } } -std::unique_ptr HloInstruction::Clone() { +HloInstruction::~HloInstruction() {} + +std::unique_ptr HloInstruction::Clone(const string& suffix) { std::unique_ptr clone = CloneWithNewOperands(shape_, operands_); - clone->name_ = name() + ".clone"; + if (suffix.empty()) { + clone->name_ = name(); + } else { + // If an instruction is cloned multiple times avoid names like + // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric + // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the + // clone of foo.suffix2 is named foo.suffix3 and so on. + const string dot_suffix = "." + suffix; + size_t index = name().rfind(dot_suffix); + if (index == string::npos) { + // Existing name does not include ".suffix". + clone->name_ = name() + dot_suffix; + } else { + // Existing name includes ".suffix". Determine if substring after + // ".suffix" is numeric and should be replaced with an incremented number. + string after_suffix = name().substr(index + dot_suffix.size()); + if (after_suffix.empty()) { + // Existing name ends in ".suffix". New name should end in ".suffix2". + clone->name_ = name() + "2"; + } else { + // If names ends with .suffix[0-9]+ then replace with a suffix with the + // numeric value incremented. + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + clone->name_ = + StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); + } else { + // Substring after ".suffix" is non-numeric. 
+ clone->name_ = name() + dot_suffix; + } + } + } + } + clone->set_parent(parent()); + clone->set_metadata(metadata_); return clone; } std::unique_ptr HloInstruction::CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands) { CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(parent() != nullptr); + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); auto new_instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); @@ -879,6 +917,11 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( // Create the list of fused parameters by mapping through the cloned, // fused instructions. std::vector new_fused_parameters; + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); + const std::list>& fused_instructions_ = + fused_instructions_computation_->instructions(); + for (HloInstruction* old_fused_parameter : fused_parameters_) { new_fused_instructions.push_back(old_fused_parameter->Clone()); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); @@ -905,16 +948,24 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( old_fused_instruction->CloneWithNewOperands( old_fused_instruction->shape(), new_operands)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); + new_fused_instruction->set_parent(parent()); new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } + new_instruction->fusion_kind_ = fusion_kind_; + auto computation_builder = HloComputation::Builder( + fused_instructions_computation_->name() + ".clone", true); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. 
- std::reverse(new_fused_instructions.begin(), new_fused_instructions.end()); - new_instruction->fusion_kind_ = fusion_kind_; - new_instruction->fused_instructions_ = std::move(new_fused_instructions); - new_instruction->fused_parameters_ = std::move(new_fused_parameters); - new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_); + for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); + new_fused_instruction_iter != new_fused_instructions.rend(); + ++new_fused_instruction_iter) { + computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); + } + auto fused_root_ = fused_expression_root(); + new_instruction->fused_instructions_computation_ = + computation_builder.Build(FindOrDie(old_to_new, fused_root_)); + new_instruction->set_parent(parent()); new_instruction->CheckFusionInstruction(); return new_instruction; } @@ -969,12 +1020,43 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand"; } +Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { + TF_RET_CHECK(instruction->parent() == parent()); + if (std::find(control_successors_.begin(), control_successors_.end(), + instruction) == control_successors_.end()) { + control_successors_.push_back(instruction); + TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), + this) == instruction->control_predecessors_.end()); + instruction->control_predecessors_.push_back(this); + } + return Status::OK(); +} + +Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { + auto succ_it = std::find(control_successors_.begin(), + control_successors_.end(), instruction); + TF_RET_CHECK(succ_it != control_successors_.end()); + control_successors_.erase(succ_it); + auto pred_it = std::find(instruction->control_predecessors_.begin(), + instruction->control_predecessors_.end(), this); + TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); + instruction->control_predecessors_.erase(pred_it); + + return Status::OK(); +} + void HloInstruction::AppendOperand(HloInstruction* operand) { operands_.push_back(operand); operand->AddUser(this); } -void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); } +void HloInstruction::AddUser(HloInstruction* user) { + if (!ContainsKey(user_set_, user)) { + user_set_.insert(user); + users_.push_back(user); + } +} bool HloInstruction::IsConstant() const { return opcode_ == HloOpcode::kConstant; @@ -989,14 +1071,6 @@ bool HloInstruction::HasConstantOperand() const { return false; } -void HloInstruction::AddControlPredecessor(HloInstruction* instruction) { - control_predecessors_.insert(instruction); -} - -void HloInstruction::AddControlSuccessor(HloInstruction* instruction) { - control_successors_.insert(instruction); -} - bool HloInstruction::Identical( const HloInstruction& other, std::function<bool(const HloInstruction*, const HloInstruction*)> @@ -1012,7 +1086,7 @@ bool HloInstruction::Identical( // general, there is no need to check shape because shape is inferred from the // shape of the operands.
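AddControlDependencyTo and RemoveControlDependencyTo above maintain a control edge on both endpoints, so removal must erase from each vector using an iterator obtained from that same vector. A toy version of that bookkeeping:

```c++
#include <algorithm>
#include <cassert>
#include <vector>

struct Node {
  std::vector<Node*> control_successors;
  std::vector<Node*> control_predecessors;
};

void AddControlEdge(Node* from, Node* to) {
  // Record the edge on both endpoints, skipping duplicates.
  if (std::find(from->control_successors.begin(),
                from->control_successors.end(),
                to) == from->control_successors.end()) {
    from->control_successors.push_back(to);
    to->control_predecessors.push_back(from);
  }
}

void RemoveControlEdge(Node* from, Node* to) {
  // Each erase uses an iterator from the vector it mutates.
  auto succ_it = std::find(from->control_successors.begin(),
                           from->control_successors.end(), to);
  assert(succ_it != from->control_successors.end());
  from->control_successors.erase(succ_it);
  auto pred_it = std::find(to->control_predecessors.begin(),
                           to->control_predecessors.end(), from);
  assert(pred_it != to->control_predecessors.end());
  to->control_predecessors.erase(pred_it);
}

int main() {
  Node a, b;
  AddControlEdge(&a, &b);
  assert(a.control_successors.size() == 1);
  assert(b.control_predecessors.size() == 1);
  RemoveControlEdge(&a, &b);
  assert(a.control_successors.empty());
  assert(b.control_predecessors.empty());
  return 0;
}
```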
if (opcode() != other.opcode() || - !ContainersEqual(operands(), other.operands(), eq_operands)) { + !ContainersEqual(operands(), other.operands(), std::move(eq_operands))) { return false; } @@ -1033,6 +1107,7 @@ bool HloInstruction::Identical( case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: + case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLogicalAnd: @@ -1156,9 +1231,14 @@ bool HloInstruction::IsRank2Transpose() const { } void HloInstruction::RemoveUser(HloInstruction* user) { - auto user_it = users_.find(user); - CHECK(user_it != users_.end()); - users_.erase(user_it); + auto set_it = user_set_.find(user); + CHECK(set_it != user_set_.end()); + user_set_.erase(set_it); + // This is linear in the number of the users, but a vector provides a stable + // iteration order and much faster traversal. + auto vec_it = std::find(users_.begin(), users_.end(), user); + CHECK(vec_it != users_.end()); + users_.erase(vec_it); } Status HloInstruction::ReplaceUseWith(HloInstruction* user, @@ -1167,15 +1247,12 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); - auto user_it = std::find(users_.begin(), users_.end(), user); - TF_RET_CHECK(user_it != users_.end()) - << "Instruction " << user->name() << " not a use of instruction " - << name(); - users_.erase(user_it); VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); + RemoveUser(user); + TF_RET_CHECK( std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, @@ -1207,30 +1284,37 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { - // We can't use range-based loop because the iterator is invalidated by call - // to ReplaceUseWith. - for (auto user = users_.begin(); user != users_.end();) { - auto this_user = user; - user++; - // It's possible that new_producer is a user of this instruction as might - // be the case when replacing an instruction with a kCopy of itself. In - // this case, don't do the replacement to avoid creating a cycle in the - // graph. - if (*this_user != new_producer) { - TF_RETURN_IF_ERROR(ReplaceUseWith(*this_user, new_producer)); + bool new_producer_is_user = false; + for (HloInstruction* user : users()) { + if (user == new_producer) { + // It's possible that new_producer is a user of this instruction as might + // be the case when replacing an instruction with a kCopy of itself. In + // this case, don't do the replacement to avoid creating a cycle in the + // graph. new_producer remains the only user of this instruction. + new_producer_is_user = true; + } else { + std::replace(user->operands_.begin(), user->operands_.end(), this, + new_producer); + new_producer->AddUser(user); } } + users_.clear(); + user_set_.clear(); + if (new_producer_is_user) { + AddUser(new_producer); + } + return Status::OK(); } void HloInstruction::DetachFromOperands() { CHECK_EQ(0, user_count()); - // An intruction may be repeated as an operand. To avoid calling RemoveUser + // An instruction may be repeated as an operand. To avoid calling RemoveUser // twice on the same operand, keep a set of already detached operands. 
std::set detached_operands; for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { HloInstruction* operand = operands_[operand_num]; - if (detached_operands.count(operand) == 0) { + if (!ContainsKey(detached_operands, operand)) { operand->RemoveUser(this); detached_operands.insert(operand); } @@ -1244,22 +1328,29 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - return to_apply_; + CHECK_EQ(called_computations_.size(), 1); + return called_computations_[0]; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } void HloInstruction::set_to_apply(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); switch (opcode_) { case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - to_apply_ = computation; + CHECK_EQ(called_computations_.size(), 1); + called_computations_[0] = computation; break; default: - LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString(); + LOG(FATAL) << "Invalid opcode for to_apply(): " + << HloOpcodeString(opcode()); } } @@ -1275,55 +1366,75 @@ const string& HloInstruction::outfeed_config() const { HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return condition_; + return called_computations_[kConditionComputationIndex]; } HloComputation* HloInstruction::while_body() const { CHECK_EQ(HloOpcode::kWhile, opcode_); - return body_; + return called_computations_[kBodyComputationIndex]; } void HloInstruction::set_while_condition(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - condition_ = computation; + called_computations_[kConditionComputationIndex] = computation; } void HloInstruction::set_while_body(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kWhile, opcode_); - body_ = computation; + called_computations_[kBodyComputationIndex] = computation; } HloComputation* HloInstruction::select() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return select_; + return called_computations_[kSelectComputationIndex]; } HloComputation* HloInstruction::scatter() const { CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return scatter_; + return called_computations_[kScatterComputationIndex]; } void HloInstruction::set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - select_ = computation; + called_computations_[kSelectComputationIndex] = computation; } void HloInstruction::set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. 
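After this hunk, `to_apply`, the while condition and body, and the select/scatter computations are all slots in a single `called_computations_` vector rather than separate pointer fields, with per-opcode index constants choosing the slot. A small sketch of that storage scheme, using hypothetical names and `assert` in place of `CHECK_EQ`:

```cpp
#include <cassert>
#include <vector>

struct Computation {};  // Stand-in for HloComputation.

enum Opcode { kWhile, kSelectAndScatter };

// Slot indices into called_computations, matching the enum in the patch.
enum {
  kBodyIndex = 0,       // kWhile
  kConditionIndex = 1,  // kWhile
  kSelectIndex = 0,     // kSelectAndScatter
  kScatterIndex = 1,    // kSelectAndScatter
};

struct Instruction {
  Opcode opcode;
  std::vector<Computation*> called_computations;

  Computation* while_body() const {
    assert(opcode == kWhile);
    return called_computations[kBodyIndex];
  }
  Computation* while_condition() const {
    assert(opcode == kWhile);
    return called_computations[kConditionIndex];
  }
};

int main() {
  Computation body, cond;
  Instruction loop{kWhile, {&body, &cond}};  // Body in slot 0, condition in 1.
  assert(loop.while_body() == &body);
  assert(loop.while_condition() == &cond);
}
```

The single vector is what makes the generic `called_computations()` accessor and the `calls=` branch of `ToString` possible without opcode-specific plumbing.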
+ CHECK(!IsFused()); CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - scatter_ = computation; + called_computations_[kScatterComputationIndex] = computation; } string HloInstruction::SignatureString() const { - string operands = tensorflow::str_util::Join( - operands_, ", ", [](string* out, HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + string operands = + Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); - return tensorflow::strings::StrCat("(", operands, ") -> ", - ShapeUtil::HumanString(shape())); + return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ToString(bool compact_operands) const { +string HloInstruction::ExtendedOpcodeStr() const { + string opc_name = HloOpcodeString(opcode()); + HloOpcode opc = opcode(); + if (HloOpcode::kFusion == opc) { + opc_name += ":" + xla::ToString(fusion_kind()); + } + return opc_name; +} + +string HloInstruction::ToString(bool compact_operands, + bool include_metadata) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. @@ -1337,120 +1448,140 @@ string HloInstruction::ToString(bool compact_operands) const { // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. for (const auto& s : v) { - if (s.empty()) continue; - tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s); + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); first = false; } } else { // Do not show large constants. operands = "{...}"; } + } else if (opcode() == HloOpcode::kParameter) { + operands = Printf("%lld", parameter_number_); } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = tensorflow::str_util::Join( - slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact_operands) { - tensorflow::strings::StrAppend(out, " ", operand->name()); - } - }); + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + *out += ShapeUtil::HumanStringWithLayout(operand->shape()); + if (!compact_operands) { + StrAppend(out, " ", operand->name()); + } + }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { - tensorflow::strings::StrAppend(&operands, ", ...(+", remaining, ")"); + StrAppend(&operands, ", ...(+", remaining, ")"); } } string extra; if (CanHaveDimensionsField()) { - tensorflow::strings::StrAppend( - &extra, ", dimensions={", tensorflow::str_util::Join(dimensions(), ","), - "}"); + StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}"); } if (window_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", ", - window_util::ToString(*window_)); + StrAppend(&extra, ", ", window_util::ToString(*window_)); } if (padding_config_ != nullptr) { - tensorflow::strings::StrAppend( - &extra, ", padding=", padding_config_->ShortDebugString()); + StrAppend(&extra, ", padding=", padding_config_->ShortDebugString()); } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector bounds; + bounds.reserve(slice_starts_.size()); for (int i = 0; i < slice_starts_.size(); ++i) { - 
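The compact `ToString` path above caps the printed operand list at `kMaxOperandsToShowIfCompact` entries and appends a `...(+N)` marker for the remainder. The same truncation logic in isolation, with a plain vector in place of `ArraySlice`:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Prints at most max_shown operands, then a ", ...(+N)" remainder marker,
// mirroring the compact-operands branch in the hunk above.
std::string FormatOperands(const std::vector<std::string>& operands,
                           std::size_t max_shown) {
  const std::size_t shown = std::min(max_shown, operands.size());
  std::string out;
  for (std::size_t i = 0; i < shown; ++i) {
    if (i > 0) out += ", ";
    out += operands[i];
  }
  if (shown < operands.size()) {
    out += ", ...(+" + std::to_string(operands.size() - shown) + ")";
  }
  return out;
}

int main() {
  std::cout << FormatOperands({"a", "b", "c", "d", "e", "f"}, 4) << "\n";
  // Prints: a, b, c, d, ...(+2)
}
```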
bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":", - slice_limits_[i], "]")); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); } - tensorflow::strings::StrAppend( - &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}"); + StrAppend(&extra, ", slice={", Join(bounds, ", "), "}"); } + if (convolution_dimension_numbers_ != nullptr) { - const auto& dnums = *convolution_dimension_numbers_; - - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - for (int64 logical = 0; logical < dims.size(); ++logical) { - int64 physical = logical; - if (!shape.layout().minor_to_major().empty()) { - physical = LayoutUtil::Major(shape.layout(), logical); - } - extra += dims[physical]; - } - }; - - // lhs_dims[i] is the symbol of the logical dimension i for the lhs - // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". - std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); - lhs_dims[dnums.batch_dimension()] = 'b'; - lhs_dims[dnums.feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - lhs_dims[dnums.spatial_dimensions(i)] = tensorflow::strings::StrCat(i); - } - - std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); - rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; - rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - rhs_dims[dnums.kernel_spatial_dimensions(i)] = - tensorflow::strings::StrCat(i); - } - - extra += " dims: "; - append_dims(lhs_dims, operands_.at(0)->shape()); - extra += "_"; - append_dims(rhs_dims, operands_.at(1)->shape()); - extra += "->"; - append_dims(lhs_dims, shape()); - } - if (to_apply_ != nullptr) { - tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name()); + StrAppend(&extra, ", ", ConvolutionDimensionNumbersToString()); } + if (opcode() == HloOpcode::kWhile) { - tensorflow::strings::StrAppend(&extra, - ", condition=", while_condition()->name()); - tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name()); + StrAppend(&extra, ", condition=", while_condition()->name()); + StrAppend(&extra, ", body=", while_body()->name()); + } else if (opcode() == HloOpcode::kSelectAndScatter) { + StrAppend(&extra, ", select=", select()->name()); + StrAppend(&extra, ", scatter=", scatter()->name()); + } else if (!called_computations().empty()) { + StrAppend(&extra, ", calls=", + Join(called_computations(), ", ", + [](string* out, const HloComputation* computation) { + StrAppend(out, computation->name()); + })); } + if (opcode() == HloOpcode::kGetTupleElement) { - tensorflow::strings::StrAppend(&extra, ", index=", tuple_index()); + StrAppend(&extra, ", index=", tuple_index()); } - return tensorflow::strings::Printf( - "%s = %s %s(%s)%s", name().c_str(), - ShapeUtil::HumanStringWithLayout(shape()).c_str(), - HloOpcodeString(opcode()).c_str(), operands.c_str(), extra.c_str()); + if (include_metadata && + (!metadata_.op_type().empty() || !metadata_.op_name().empty() || + !metadata_.source_file().empty())) { + StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); + } + + return Printf("%s = %s %s(%s)%s", name().c_str(), + ShapeUtil::HumanStringWithLayout(shape()).c_str(), + ExtendedOpcodeStr().c_str(), operands.c_str(), extra.c_str()); } string 
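The slice branch of `ToString` renders the parallel `slice_starts_`/`slice_limits_` arrays as `slice={[s:l], ...}`. A standalone approximation using only the standard library (the real code builds the string with `StrCat` and `Join` from `tensorflow::strings`):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Renders parallel start/limit arrays as "slice={[s0:l0], [s1:l1], ...}",
// like the slice branch of ToString in the hunk above.
std::string SliceBoundsString(const std::vector<long>& starts,
                              const std::vector<long>& limits) {
  std::string out = "slice={";
  for (std::size_t i = 0; i < starts.size(); ++i) {
    if (i > 0) out += ", ";
    out += "[" + std::to_string(starts[i]) + ":" +
           std::to_string(limits[i]) + "]";
  }
  return out + "}";
}

int main() {
  std::cout << SliceBoundsString({0, 4}, {2, 8}) << "\n";
  // Prints: slice={[0:2], [4:8]}
}
```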
HloInstruction::ToShortString() const { - return tensorflow::strings::Printf( - "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(), - tensorflow::str_util::Join(operands_, ", ", [](string* out, - HloInstruction* operand) { - tensorflow::strings::StrAppend(out, operand->name()); - }).c_str()); + return Printf("%s = %s(%s)", name().c_str(), + HloOpcodeString(opcode()).c_str(), + Join(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, operand->name()); + }) + .c_str()); +} + +HloInstructionProto HloInstruction::ToProto() const { + HloInstructionProto proto; + proto.set_name(name_); + proto.set_opcode(HloOpcodeString(opcode_)); + *proto.mutable_shape() = shape_; + for (const HloInstruction* operand : operands_) { + *proto.add_operand_names() = operand->name(); + } + for (const HloInstruction* control : control_predecessors_) { + *proto.add_control_predecessor_names() = control->name(); + } + for (const HloComputation* computation : called_computations_) { + *proto.add_called_computation_names() = computation->name(); + } + *proto.mutable_metadata() = metadata_; + switch (opcode_) { + case HloOpcode::kConstant: + *proto.mutable_literal() = literal_->ToProto(); + break; + case HloOpcode::kParameter: + proto.set_parameter_number(parameter_number_); + proto.set_parameter_name(parameter_name_); + break; + case HloOpcode::kFusion: { + HloComputationProto* proto_fused_computation = + proto.mutable_fused_instructions_computation(); + proto_fused_computation->set_name(FullyQualifiedName()); + + // Fill in fused instructions. Note that fused_instructions() returns in + // reverse post-order (i.e. root first), so we reverse to get post-order. + for (auto fused_it = fused_instructions().rbegin(); + fused_it != fused_instructions().rend(); ++fused_it) { + HloInstructionProto fused_proto = (*fused_it)->ToProto(); + proto_fused_computation->add_instructions()->Swap(&fused_proto); + } + break; + } + case HloOpcode::kGetTupleElement: + proto.set_tuple_index(tuple_index_); + break; + default: {} // Nothing to do + } + return proto; } string HloInstruction::ToCategory() const { @@ -1482,10 +1613,22 @@ string HloInstruction::ToCategory() const { return "rank-1-broadcast binary fusion"; } } - if (IsElementwise()) { - return "elementwise fusion"; - } else { - return "non-elementwise fusion"; + switch (fusion_kind()) { + case FusionKind::kLoop: + if (IsElementwise()) { + return "elementwise fusion"; + } else { + return "non-elementwise fusion"; + } + case FusionKind::kInput: + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; + case FusionKind::kTransposeDot: + return "dot fusion"; + case FusionKind::kConvBackwardFilter: + case FusionKind::kConvBackwardInput: + return "convolution fusion"; } } @@ -1496,16 +1639,24 @@ string HloInstruction::ToCategory() const { return HloOpcodeString(opcode()); } +string HloInstruction::FullyQualifiedName() const { + if (IsFused()) { + return StrCat(fusion_instruction()->parent()->name(), + "::", fusion_instruction()->name(), "::", name_); + } + return StrCat(parent_->name(), "::", name_); +} + HloInstruction* HloInstruction::tracing() const { return trace_instruction_; } void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -const string& HloInstruction::tracing_tag() const { +string HloInstruction::TracingTag() const { CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK(literal_ != nullptr); - return literal_->u8s(); + return literal_->u8s_string(); } bool 
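`ToProto` above iterates the fused instructions with `rbegin()/rend()` because the list is stored root-first (reverse post-order) while the serialized form wants producers before consumers (post-order). The inversion in miniature:

```cpp
#include <iostream>
#include <list>
#include <string>

int main() {
  // Stored root-first, i.e. reverse post-order, like fused_instructions().
  std::list<std::string> reverse_postorder = {"add", "exp", "param0"};

  // Walking it backwards yields post-order: operands before their users,
  // which is the order ToProto serializes the fused computation in.
  for (auto it = reverse_postorder.rbegin(); it != reverse_postorder.rend();
       ++it) {
    std::cout << *it << "\n";  // param0, then exp, then add.
  }
}
```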
HloInstruction::IsFused() const { @@ -1520,7 +1671,6 @@ bool HloInstruction::IsFusable() const { // Some kinds of instructions don't make sense to fuse. switch (opcode_) { - case HloOpcode::kFusion: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kParameter: @@ -1528,11 +1678,20 @@ bool HloInstruction::IsFusable() const { case HloOpcode::kSend: case HloOpcode::kRecv: return false; + // Only fuse Rng if it is used once, otherwise the random numbers generated + // will be different in each fusion. + case HloOpcode::kRng: + return users_.size() == 1; default: return true; } } +HloComputation* HloInstruction::fused_instructions_computation() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + return fused_instructions_computation_.get(); +} + HloInstruction* HloInstruction::fusion_instruction() const { CHECK(IsFused()); return parent_fusion_instruction_; @@ -1540,20 +1699,32 @@ HloInstruction* HloInstruction::fusion_instruction() const { HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_root_; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->root_instruction(); } HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_GE(parameter_number, 0); - CHECK_LT(parameter_number, fused_parameters_.size()); - return fused_parameters_[parameter_number]; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->parameter_instruction( + parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->parameter_instructions(); } const std::list>& HloInstruction::fused_instructions() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->instructions(); } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -1619,16 +1790,16 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kTuple: return visitor->HandleTuple(this, operands_); case HloOpcode::kMap: - return visitor->HandleMap(this, operands_, to_apply_, {}); + return visitor->HandleMap(this, operands_, to_apply(), {}); case HloOpcode::kClamp: return visitor->HandleClamp(this, operands_[0], operands_[1], operands_[2]); case HloOpcode::kReduce: return visitor->HandleReduce(this, operands_[0], operands_[1], - dimensions_, to_apply_); + dimensions_, to_apply()); case HloOpcode::kReduceWindow: return visitor->HandleReduceWindow(this, operands_[0], window(), - to_apply_); + to_apply()); case HloOpcode::kSelectAndScatter: return visitor->HandleSelectAndScatter(this); case HloOpcode::kNegate: @@ -1643,6 +1814,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kIsFinite: + return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: return visitor->HandleLogicalNot(this, operands_[0]); case HloOpcode::kBitcast: @@ -1660,7 
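With this change the fusion accessors (`fused_expression_root`, `fused_parameter`, `fused_parameters`, `fused_instructions`) no longer read cached fields on the instruction; they delegate to the embedded computation. A sketch of that delegation, with hypothetical `Fusion`/`Computation` stand-ins:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct Instr {};  // Stand-in for HloInstruction.

// Stand-in for the fusion's embedded HloComputation.
struct Computation {
  Instr* root = nullptr;
  std::vector<Instr*> parameters;
};

// The fusion accessors delegate to the computation instead of caching
// fused_root_ / fused_parameters_ on the instruction, as in the hunk above.
struct Fusion {
  std::unique_ptr<Computation> fused_computation;

  Instr* fused_expression_root() const {
    assert(fused_computation != nullptr);
    return fused_computation->root;
  }
  Instr* fused_parameter(std::size_t i) const {
    assert(fused_computation != nullptr);
    assert(i < fused_computation->parameters.size());
    return fused_computation->parameters[i];
  }
};

int main() {
  Instr root, p0;
  Fusion f;
  f.fused_computation = std::make_unique<Computation>();
  f.fused_computation->root = &root;
  f.fused_computation->parameters = {&p0};
  assert(f.fused_expression_root() == &root);
  assert(f.fused_parameter(0) == &p0);
}
```

A single source of truth in the computation removes the risk of the cached root and parameter vector drifting out of sync as instructions are fused in.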
+1833,7 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSlice: return visitor->HandleSlice(this, operands_[0]); case HloOpcode::kDynamicSlice: - return visitor->HandleDynamicSlice(this, operands_); + return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]); case HloOpcode::kDynamicUpdateSlice: return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1], operands_[2]); @@ -1673,11 +1846,11 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kRng: return visitor->HandleRng(this, distribution_); case HloOpcode::kWhile: - return visitor->HandleWhile(this, operands_[0], condition_, body_); + return visitor->HandleWhile(this); case HloOpcode::kFusion: return visitor->HandleFusion(this); case HloOpcode::kCall: - return visitor->HandleCall(this, operands_, to_apply_); + return visitor->HandleCall(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this, operands_, custom_call_target_); case HloOpcode::kSend: @@ -1695,7 +1868,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } -Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) { +Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, + const CompareFunction* operand_order, + bool ignore_control_predecessors) { // Do not visit this HLO node again if it is already visited. if (visitor->DidVisit(*this)) { VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; @@ -1710,16 +1885,41 @@ Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) { } visitor->SetVisiting(*this); - for (auto operand : operands_) { + // Sort operands, if an ordering was provided. 'temp_sorted_operands' must + // live at this scope, since 'operands' will point to it if the operands are + // sorted. The purpose of the 'operands' pointer is to avoid copying the + // operands in the common case where the operands are not sorted. + std::vector* operands = &operands_; + std::vector temp_sorted_operands; + if (operand_order != nullptr) { + temp_sorted_operands = operands_; + std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), + *operand_order); + operands = &temp_sorted_operands; + } + for (HloInstruction* operand : *operands) { VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor)); + TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, + ignore_control_predecessors)); } - for (auto control_predecessor : control_predecessors_) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal(visitor)); + if (!ignore_control_predecessors) { + // This uses the same pointer/vector sorting to avoid extra copies as above. 
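`AcceptInternal` above avoids copying `operands_` in the common unordered case by pointing at the member vector and only materializing a sorted temporary when a comparator is supplied. The same pointer-or-copy pattern in a runnable form:

```cpp
#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

using Compare = std::function<bool(int, int)>;

// Visits items in comparator order if cmp is non-null; otherwise visits them
// as stored, without paying for a copy (the common case).
void Visit(const std::vector<int>& items, const Compare* cmp) {
  const std::vector<int>* to_visit = &items;
  std::vector<int> sorted;  // Must outlive the loop below if used.
  if (cmp != nullptr) {
    sorted = items;
    std::sort(sorted.begin(), sorted.end(), *cmp);
    to_visit = &sorted;
  }
  for (int v : *to_visit) std::cout << v << " ";
  std::cout << "\n";
}

int main() {
  std::vector<int> items = {3, 1, 2};
  Visit(items, nullptr);  // 3 1 2 (no copy made)
  Compare desc = [](int a, int b) { return a > b; };
  Visit(items, &desc);    // 3 2 1 (sorted copy)
}
```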
+ std::vector* predecessors = &control_predecessors_; + std::vector temp_sorted_predecessors; + if (operand_order != nullptr) { + temp_sorted_predecessors = control_predecessors_; + std::sort(temp_sorted_predecessors.begin(), + temp_sorted_predecessors.end(), *operand_order); + predecessors = &temp_sorted_predecessors; + } + for (HloInstruction* control_predecessor : *predecessors) { + VLOG(3) << "Going to visit HLO " << control_predecessor->name() + << " as a control predecessor of HLO " << name(); + TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( + visitor, operand_order, ignore_control_predecessors)); + } } TF_RETURN_IF_ERROR(visitor->Preprocess(this)); @@ -1729,18 +1929,27 @@ Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) { return visitor->Postprocess(this); } -Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) { +Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, + bool ignore_control_predecessors) { VLOG(2) << "HloInstruction::Accept(" << name() << ")"; - auto status = AcceptInternal(visitor); - if (!status.ok()) { - return status; - } - + TF_RETURN_IF_ERROR( + AcceptInternal(visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { - return visitor->FinishVisit(this); - } else { - return Status::OK(); + TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } + return Status::OK(); +} + +Status HloInstruction::AcceptWithOperandOrder( + DfsHloVisitor* visitor, const CompareFunction& operand_order, + bool call_finish_visit) { + VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; + TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, + /*ignore_control_predecessors=*/false)); + if (call_finish_visit) { + TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); + } + return Status::OK(); } namespace { @@ -1761,7 +1970,7 @@ bool OrderIsTopologicalSort(const std::vector& order) { // ops). for (auto* instruction : order) { for (auto* operand : instruction->operands()) { - if (order_position.count(operand) == 0 || + if (!ContainsKey(order_position, operand) || order_position.at(operand) >= order_position.at(instruction)) { return false; } @@ -1773,7 +1982,8 @@ bool OrderIsTopologicalSort(const std::vector& order) { } // namespace -Status HloInstruction::Accept(FunctionVisitor::VisitorFunction visitor_func) { +Status HloInstruction::Accept( + const FunctionVisitor::VisitorFunction& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } @@ -1791,7 +2001,7 @@ Status HloInstruction::AcceptOrdered( })); for (auto* const_instruction : order) { - if (predecessors.count(const_instruction) == 0) { + if (!ContainsKey(predecessors, const_instruction)) { // Instruction is not a predecessors of 'this'. continue; } @@ -1817,6 +2027,12 @@ Status HloInstruction::AcceptOrdered( return visitor->FinishVisit(this); } +const Shape& HloInstruction::outfeed_shape() const { + DCHECK_EQ(opcode_, HloOpcode::kOutfeed); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); + return outfeed_shape_; +} + const Shape& HloInstruction::shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; @@ -1846,6 +2062,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCopy: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: @@ -1879,6 +2096,7 @@ bool HloInstruction::IsElementwise() const { return true; // Other operations. 
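`OrderIsTopologicalSort` above builds a position map and then verifies that every operand appears strictly before each of its users. A self-contained version over a toy dependency graph:

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<const Node*> operands;
};

// True iff every node's operands appear before the node itself in `order`.
bool IsTopologicalSort(const std::vector<const Node*>& order) {
  std::unordered_map<const Node*, std::size_t> position;
  for (std::size_t i = 0; i < order.size(); ++i) position[order[i]] = i;
  for (const Node* node : order) {
    for (const Node* operand : node->operands) {
      auto it = position.find(operand);
      if (it == position.end() || it->second >= position.at(node)) {
        return false;
      }
    }
  }
  return true;
}

int main() {
  Node a;
  Node b{{&a}};       // b depends on a.
  Node c{{&a, &b}};   // c depends on a and b.
  assert(IsTopologicalSort({&a, &b, &c}));
  assert(!IsTopologicalSort({&b, &a, &c}));
}
```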
+ case HloOpcode::kRng: case HloOpcode::kMap: return true; case HloOpcode::kFusion: @@ -1932,7 +2150,7 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { HloInstruction* operand = worklist.front(); worklist.pop_front(); for (HloInstruction* user : operand->users()) { - if (visited.count(user)) { + if (ContainsKey(visited, user)) { continue; } if (user->IsElementwise() || @@ -1947,6 +2165,70 @@ bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { return true; } +// A helper class for memoized, recursive computation of HloOpcode::kFusion +// in HloInstruction::OperandElementUse below. +class HloInstruction::FusionReusesParamElements { + public: + using UseKind = HloInstruction::UseKind; + + // We could rather iterate backwards thru fused_instructions_ here, as it is + // in reverse postorder, and compute whether each fused instruction reuses the + // value of this parameter, which would save stack space but not allow us to + // finish early if we find a reuse. + static UseKind Compute(int64 i, const HloInstruction& hlo) { + tensorflow::gtl::FlatMap memoization_cache; + return ComputeInternal(i, hlo, &memoization_cache); + } + + private: + static UseKind ComputeInternal( + int64 i, const HloInstruction& hlo, + tensorflow::gtl::FlatMap* cache) { + if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { + return UseKind::kUse; + } + + auto p = cache->emplace(&hlo, UseKind{}); + auto value_it = p.first; + const bool key_is_new = p.second; + + if (key_is_new) { + for (int64 j = 0; j < hlo.operands_.size(); ++j) { + UseKind old_val = value_it->second; + + // The next operation invalidates iterators. + UseKind new_val = + Plus(old_val, std::min(hlo.OperandElementUse(j), + ComputeInternal(i, *hlo.operand(j), cache))); + + // Re-acquire the iterator. We could work harder to do this only if + // absolutely necessary, but this code is not hot enough to warrant + // that. + value_it = cache->find(&hlo); + value_it->second = new_val; + } + } + return value_it->second; + } + + // Fold operation for UseKinds. + static UseKind Plus(UseKind a, UseKind b) { + if (a == UseKind::kNoUse) { + return b; + } else if (b == UseKind::kNoUse) { + return a; + } else if (a == UseKind::kReuse || b == UseKind::kReuse) { + return UseKind::kReuse; + } else if (a == UseKind::kUsePermutingElements || + b == UseKind::kUsePermutingElements) { + return UseKind::kReuse; + } else { + CHECK(a == UseKind::kUse && b == UseKind::kUse); + return UseKind::kUse; + } + } +}; + HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { switch (opcode_) { case HloOpcode::kBitcast: @@ -1961,69 +2243,14 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { // Pad reuses the padding value but not the padded array elements. // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; - case HloOpcode::kFusion: { - tensorflow::gtl::FlatMap cache; - // We could rather iterate backwards thru fused_instructions_ here, as it - // is in reverse postorder, and compute whether each fused instruction - // reuses the value of this parameter, which would save stack space but - // not allow us to finish early if we find a reuse. 
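`FusionReusesParamElements::ComputeInternal` above memoizes per-instruction results but deliberately re-finds its cache iterator after each recursive call, since inserting into the map during recursion can rehash and invalidate iterators held across the call. The hazard and the fix, reduced to a memoized reachability count over `std::unordered_map`:

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

struct Node {
  int id;
  std::vector<const Node*> children;
};

// Counts nodes reachable from n (per path), memoized. After each recursive
// call we re-find the iterator: the recursion may have inserted into `cache`
// and triggered a rehash, invalidating any iterator held across the call.
int CountReachable(const Node* n,
                   std::unordered_map<const Node*, int>* cache) {
  auto emplaced = cache->emplace(n, 1);  // Count n itself.
  auto it = emplaced.first;
  if (emplaced.second) {  // Key was new: fill in the children.
    for (const Node* child : n->children) {
      int subtotal = CountReachable(child, cache);  // May invalidate `it`.
      it = cache->find(n);                          // Re-acquire.
      it->second += subtotal;
    }
  }
  return it->second;
}

int main() {
  Node leaf{0, {}};
  Node mid{1, {&leaf}};
  Node root{2, {&mid, &leaf}};  // leaf is reachable along two paths.
  std::unordered_map<const Node*, int> cache;
  assert(CountReachable(&root, &cache) == 4);  // root + mid + leaf + leaf.
}
```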
- std::function reuses_parameter_elements = - [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) { - auto plus = [](const UseKind& a, const UseKind& b) { - if (a == UseKind::kNoUse) return b; - if (b == UseKind::kNoUse) return a; - if (a == UseKind::kReuse || b == UseKind::kReuse) { - return UseKind::kReuse; - } - if (a == UseKind::kUsePermutingElements || - b == UseKind::kUsePermutingElements) { - return UseKind::kReuse; - } - CHECK(UseKind::kUse == a && UseKind::kUse == b); - return UseKind::kUse; - }; - - if (hlo.opcode_ == HloOpcode::kParameter && - hlo.parameter_number_ == i) { - return UseKind::kUse; - } - if (cache.count(&hlo) == 0) { - for (int64 j = 0; j < hlo.operands_.size(); ++j) { - UseKind old = cache[&hlo]; - UseKind updated = plus( - old, std::min(hlo.OperandElementUse(j), - reuses_parameter_elements(*hlo.operand(j)))); - cache[&hlo] = updated; - } - } - return cache[&hlo]; - }; - return reuses_parameter_elements(*fused_root_); - } + case HloOpcode::kFusion: + // Uses the memoizing, recursive computation defined above. + return FusionReusesParamElements::Compute(i, *fused_expression_root()); default: return IsElementwise() ? UseKind::kUse : UseKind::kReuse; } } -namespace { - -// Prereq: `order` is a permutation of {0, 1, ..., `dims.size()-1`} -void Strip1SizedDimensions(tensorflow::protobuf::RepeatedField* dims, - std::vector* order) { - // We can't merely call StripDegenerateDimensions here as we must also delete - // the dimension indices. - for (size_t i = 0; i < dims->size(); ++i) { - if (1 == dims->Get(i)) { - dims->erase(dims->begin() + i); - // We must find this, as order must be a permutation of operand - // dimensions. - order->erase(std::find(order->begin(), order->end(), i)); - } - } -} - -} // namespace - std::tuple, std::vector> HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const { if (HloOpcode::kReshape != opcode_) { @@ -2033,21 +2260,72 @@ HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const { shape_); } -string FusionKindString(HloInstruction::FusionKind kind) { +string ToString(HloInstruction::FusionKind kind) { switch (kind) { case HloInstruction::FusionKind::kLoop: - return "Loop"; + return "kLoop"; case HloInstruction::FusionKind::kInput: - return "Input"; + return "kInput"; + case HloInstruction::FusionKind::kOutput: + return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: - return "TransposeDot"; + return "kTransposeDot"; case HloInstruction::FusionKind::kConvBackwardFilter: - return "ConvBackwardFilter"; + return "kConvBackwardFilter"; case HloInstruction::FusionKind::kConvBackwardInput: - return "ConvBackwardInput"; + return "kConvBackwardInput"; } } +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + +string HloInstruction::ConvolutionDimensionNumbersToString() const { + string result; + if (convolution_dimension_numbers_ == nullptr) { + return result; + } + const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; + // Show the given dimension labels in order of major to minor based on the + // shape's layout. 
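The inline `plus` closure removed above (now the static `Plus` in `FusionReusesParamElements`) is a join on a small lattice: `kNoUse` is the identity, a permuting use combined with any other real use collapses to `kReuse`, and two plain uses stay `kUse`. Transcribed with a plain enum:

```cpp
#include <cassert>

enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

// Join of two use kinds, mirroring the Plus fold in the patch: kNoUse is the
// identity; any reuse (or a pair involving a permuting use) yields kReuse;
// only two plain uses remain a plain use.
UseKind Plus(UseKind a, UseKind b) {
  if (a == UseKind::kNoUse) return b;
  if (b == UseKind::kNoUse) return a;
  if (a == UseKind::kReuse || b == UseKind::kReuse) return UseKind::kReuse;
  if (a == UseKind::kUsePermutingElements ||
      b == UseKind::kUsePermutingElements) {
    return UseKind::kReuse;
  }
  return UseKind::kUse;  // Both are plain kUse.
}

int main() {
  assert(Plus(UseKind::kNoUse, UseKind::kUse) == UseKind::kUse);
  assert(Plus(UseKind::kUse, UseKind::kUsePermutingElements) ==
         UseKind::kReuse);
  assert(Plus(UseKind::kUse, UseKind::kUse) == UseKind::kUse);
}
```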
+ const auto append_dims = [&](const std::vector& dims, + const Shape& shape) { + CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); + for (int64 logical = 0; logical < dims.size(); ++logical) { + int64 physical = logical; + if (!shape.layout().minor_to_major().empty()) { + physical = LayoutUtil::Major(shape.layout(), logical); + } + result += dims[physical]; + } + }; + + // lhs_dims[i] is the symbol of the logical dimension i for the lhs + // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". + std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); + lhs_dims[dnums.batch_dimension()] = 'b'; + lhs_dims[dnums.feature_dimension()] = 'f'; + for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); + } + + std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); + rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; + rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; + for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); + } + + result += "dim_labels="; + append_dims(lhs_dims, operand(0)->shape()); + result += "_"; + append_dims(rhs_dims, operand(1)->shape()); + result += "->"; + append_dims(lhs_dims, shape()); + return result; +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -2059,4 +2337,15 @@ bool HloInstruction::CouldBeBitcast() const { } } +HloModule* HloInstruction::GetModule() const { + if (parent_) { + return parent_->parent(); + } + return nullptr; +} + +void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { + name_ = name_uniquer->GetUniqueName(name_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index ff52900a2b8..c7cd729934b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -22,16 +22,21 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ #include +#include #include #include -#include #include #include +#include #include +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -44,18 +49,23 @@ limitations under the License. namespace xla { class HloComputation; +class HloModule; // HLO instructions are the IR used by the high-level compiler. class HloInstruction { public: enum class FusionKind { kLoop, // Fused into a loop. - kInput, // Fused into a reduction kernel. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. }; + ~HloInstruction(); // Creates a parameter-retrieving instruction. 
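`ConvolutionDimensionNumbersToString` prints each operand's dimension labels in physical (major-to-minor) order by mapping logical indices through the layout. A sketch of that mapping, assuming XLA's convention that `LayoutUtil::Major(layout, i)` resolves to `minor_to_major[rank - 1 - i]`:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Prints `labels` (indexed by logical dimension) in major-to-minor physical
// order; minor_to_major[k] names the logical dimension that is k-th minor.
std::string AppendDims(const std::vector<char>& labels,
                       const std::vector<int>& minor_to_major) {
  std::string out;
  for (std::size_t i = 0; i < labels.size(); ++i) {
    // Assumed convention: Major(layout, i) == minor_to_major[rank - 1 - i].
    out += labels[minor_to_major[labels.size() - 1 - i]];
  }
  return out;
}

int main() {
  // Logical dims: 0 = batch 'b', 1 = feature 'f', 2 and 3 = spatial '0', '1'.
  std::vector<char> labels = {'b', 'f', '0', '1'};
  std::cout << AppendDims(labels, {3, 2, 1, 0}) << "\n";  // bf01
  std::cout << AppendDims(labels, {2, 3, 0, 1}) << "\n";  // fb10
}
```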
  static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
                                                         const Shape& shape,
@@ -137,7 +147,8 @@ class HloInstruction {
 
   // Creates an outfeed instruction, which outputs data.
   static std::unique_ptr<HloInstruction> CreateOutfeed(
-      HloInstruction* operand, tensorflow::StringPiece outfeed_config);
+      const Shape& shape, HloInstruction* operand,
+      tensorflow::StringPiece outfeed_config);
 
   // Creates a send instruction with the given channel id, which sends the
   // operand data to a unique receive instruction in another computation that
@@ -156,7 +167,8 @@ class HloInstruction {
   static std::unique_ptr<HloInstruction> CreateSlice(
       const Shape& shape, HloInstruction* operand,
       tensorflow::gtl::ArraySlice<int64> start_indices,
-      tensorflow::gtl::ArraySlice<int64> limit_indices);
+      tensorflow::gtl::ArraySlice<int64> limit_indices,
+      tensorflow::gtl::ArraySlice<int64> strides);
 
   // Creates a slice instruction, where the first operand is sliced by
   // start indices specified in the second operand, and by size specified in
@@ -302,28 +314,38 @@ class HloInstruction {
   int64 user_count() const { return users_.size(); }
 
   // Returns the users of this instruction.
-  const std::set<HloInstruction*>& users() const { return users_; }
+  const std::vector<HloInstruction*>& users() const { return users_; }
 
-  // Returns the set of control predecessors of this instruction. Control
-  // predecessors are the instructions that must be scheduled before the current
-  // instruction.
-  const std::set<HloInstruction*>& control_predecessors() const {
+  // Returns true if this instruction is a user of 'instruction'.
+  bool IsUserOf(const HloInstruction* instruction) const {
+    return ContainsKey(instruction->user_set_, this);
+  }
+
+  // Adds a control dependency from this instruction to the given
+  // instruction. This instruction becomes a control predecessor of
+  // 'instruction', and 'instruction' becomes a control successor of this
+  // instruction. Returns an error status if the two instructions do not
+  // belong to the same computation.
+  //
+  // This is used to enforce an additional ordering requirement that is not
+  // captured by normal data dependencies, such as ordering among Send or Recv
+  // operations to avoid deadlock.
+  Status AddControlDependencyTo(HloInstruction* instruction);
+
+  // Removes a previously added control dependency from this instruction to
+  // 'instruction'.
+  Status RemoveControlDependencyTo(HloInstruction* instruction);
+
+  // Returns the set of control predecessors (successors) of this
+  // instruction. Control predecessors (successors) must execute before (after)
+  // the current instruction.
+  const std::vector<HloInstruction*>& control_predecessors() const {
     return control_predecessors_;
   }
-
-  // Adds the given instruction to the set of control predecessors.
-  void AddControlPredecessor(HloInstruction* instruction);
-
-  // Returns the set of control successors of this instruction. Control
-  // successors are the instructions that must be scheduled after the current
-  // instruction.
-  const std::set<HloInstruction*>& control_successors() const {
+  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
   }
 
-  // Adds the given instruction to the set of control successors.
-  void AddControlSuccessor(HloInstruction* instruction);
-
   // Returns true if "other" performs the same computation as this instruction.
   // Layout of the instructions' output array is not considered.
   bool Identical(
@@ -359,12 +381,25 @@ class HloInstruction {
 
   // Performs a postorder DFS visit using this node as the root. If
   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
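`CreateSlice` now takes a `strides` argument alongside the start and limit indices. For a positive stride, the resulting dimension size follows the usual ceiling division; the formula below is the standard strided-slice bound, not quoted from the XLA implementation:

```cpp
#include <cassert>

// Output size of dimension d for the slice [start:limit:stride] with a
// positive stride: the number of k >= 0 with start + k * stride < limit,
// i.e. ceil((limit - start) / stride).
long SliceDimSize(long start, long limit, long stride) {
  assert(stride > 0 && limit >= start);
  return (limit - start + stride - 1) / stride;
}

int main() {
  assert(SliceDimSize(0, 10, 1) == 10);
  assert(SliceDimSize(0, 10, 3) == 4);  // Elements 0, 3, 6, 9.
  assert(SliceDimSize(2, 10, 4) == 2);  // Elements 2, 6.
}
```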
- Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true); + // complete. If ignore_control_predecessors is true, instructions only + // reachable via control dependencies will not be visited, and the postorder + // will not take control dependencies into account. It is as if the control + // dependencies didn't exist in the graph at all. + Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true, + bool ignore_control_predecessors = false); + + // Same as Accept() above, but the order of operand and control predecessor + // visitation is determined by the given operand order; if compare(A, B) == + // true, A is visited before B. + using CompareFunction = + std::function; + Status AcceptWithOperandOrder(DfsHloVisitor* visitor, + const CompareFunction& operand_order, + bool call_finish_visit = true); // Performs a postorder DFS visit using this node as the root. Calls the given // visitor function at each instruction. - Status Accept(FunctionVisitor::VisitorFunction visitor_func); + Status Accept(const FunctionVisitor::VisitorFunction& visitor_func); // Visits all instructions rooted at this instruction using the given visitor // in the given order. 'order' must contain at least the set of instructions @@ -397,6 +432,11 @@ class HloInstruction { return parameter_name_; } + void set_parameter_name(const string& str) { + CHECK_EQ(HloOpcode::kParameter, opcode_); + parameter_name_ = str; + } + // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, @@ -428,6 +468,10 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kOutfeed const string& outfeed_config() const; + // Returns the shape for the Outfeed instruction. + // Precondition: opcode() == HloOpcode::kOutfeed + const Shape& outfeed_shape() const; + // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -451,15 +495,26 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false) const; + string ToString(bool compact_operands = false, + bool include_metadata = true) const; + + string ToStringNoMetadata() const { return ToString(false, false); } // As ToString, but returns a shorter string. string ToShortString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const; + // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". string ToCategory() const; + // Returns the string concatenation of parent name and this instructions + // name. This name is guaranteed to be unique among all instructions in the + // HloModule. + string FullyQualifiedName() const; + // Returns a logging instruction, if the output of this instruction is logged. // // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace @@ -482,7 +537,7 @@ class HloInstruction { // Returns a tag to be used in tracing. // // Precondition: opcode() == HloOpcode::kTrace - const string& tracing_tag() const; + string TracingTag() const; // Returns whether the instruction is a constant. bool IsConstant() const; @@ -506,10 +561,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; + // Returns the computation for this fused instruction. 
+ // + // Precondition: opcode() == HloOpcode::kFusion + HloComputation* fused_instructions_computation() const; + // Returns the vector of fused instructions inside this fusion // instruction. The order is a reverse postorder of the fused expression (root // is first in the order). // + // Note: although the list itself is const, the instructions contained in the + // list returned here are mutable. + // // Precondition: opcode() == HloOpcode::kFusion const std::list>& fused_instructions() const; @@ -519,6 +582,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_parameter(int64 parameter_number) const; + // Returns the vector of fused parameters inside this fusion instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return (opcode() == HloOpcode::kFusion && + fused_expression_root()->opcode() == HloOpcode::kTuple); + } + FusionKind fusion_kind() const { CHECK_EQ(HloOpcode::kFusion, opcode_); return fusion_kind_; @@ -564,6 +639,15 @@ class HloInstruction { return slice_limits_; } + // Returns the stride in the given dimension for a slice node. + // + // Precondition: opcode() == HloOpcode::kSlice + int64 slice_stride(int64 dimension) const { + CHECK_EQ(HloOpcode::kSlice, opcode_); + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -599,6 +683,9 @@ class HloInstruction { return *convolution_dimension_numbers_; } + // Returns the dump string of the convolution dimension numbers. + string ConvolutionDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -606,18 +693,20 @@ class HloInstruction { // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction - // cloned from) is not changed. - std::unique_ptr Clone(); + // cloned from) is not changed. Suffix is the string to append to the name of + // the instruction to form the name of the cloned instruction. + std::unique_ptr Clone(const string& suffix = "clone"); // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands); - // Computes and returns the computations this instruction calls (if any). This - // includes computations called by fused instructions inside of a fusion - // instruction. - std::set MakeCalledComputationsSet() const; + // Returns the computations this instruction calls (if any). This includes + // computations called by fused instructions inside of a fusion instruction. + const std::vector& called_computations() const { + return called_computations_; + } // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, @@ -653,13 +742,23 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; + // Returns the opcode string for this instruction. Compared with + // HloOpcodeString method, this wrapper dumps additional information + // such as fusion kind. 
+ string ExtendedOpcodeStr() const; + // Returns a string identifier for this instruction. If no string identifier // has been explicitly set, then the identifier is the serialized pointer to // this instruction. const string& name() const { return name_; } - // Sets the string identifier for this instruction. - void set_name(const string& name) { name_ = name; } + // Use the given NameUniquer to select a unique name for the instruction based + // on the instruction's existing name. + void UniquifyName(NameUniquer* name_uniquer); + + // Sets the debug metadata for this instruction. + void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } + const OpMetadata& metadata() const { return metadata_; } // Set/get the computation containing this instruction. set_parent should only // be called by HloComputation methods which add/remove instructions to @@ -668,13 +767,27 @@ class HloInstruction { const HloComputation* parent() const { return parent_; } HloComputation* parent() { return parent_; } + // Returns the module for this instruction. + HloModule* GetModule() const; + // Returns whether we could assign input and output layouts to this // instruction to make it a bitcast. bool CouldBeBitcast() const; + // Sets the parent fusion instruction for this instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + void SetParentFusion(HloInstruction* fusion_instruction) { + CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + parent_fusion_instruction_ = fusion_instruction; + } + private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, @@ -707,7 +820,9 @@ class HloInstruction { // Inner DFS traversal function -- this function being called (rather than // Accept above) allows us to distinguish the root of the traversal. - Status AcceptInternal(DfsHloVisitor* visitor); + Status AcceptInternal(DfsHloVisitor* visitor, + const CompareFunction* operand_order, + bool ignore_control_predecessors); // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; @@ -719,6 +834,9 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Shape of outfeed request. + Shape outfeed_shape_; + // Result shape of this instruction. Shape shape_; @@ -744,6 +862,7 @@ class HloInstruction { // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; + std::vector slice_strides_; // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). @@ -753,22 +872,14 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // The set of instruction fused into this fusion instruction. Only set for - // fusion instructions. - std::list> fused_instructions_; + // The computation that stores of instructions fused into this fusion + // instruction. Only set for fusion instructions. + std::unique_ptr fused_instructions_computation_; // If this instruction is fused into a fusion instruction, this field points // to the fusion instruction. 
HloInstruction* parent_fusion_instruction_ = nullptr; - // The vector of parameter instructions inside this fusion instruction. The - // index of the vector is the parameter_number of the parameter instruction. - // This vector is non-empty only for fusion instructions. - std::vector fused_parameters_; - - // The root of the expression fused into this fusion instruction. - HloInstruction* fused_root_ = nullptr; - // The type of the fusion. Used by kFusion only. FusionKind fusion_kind_; @@ -776,21 +887,23 @@ class HloInstruction { int64 parameter_number_ = 0; string parameter_name_; - // Computation to apply, only present for kCall, kMap, kReduce and - // kReduceWindow. - HloComputation* to_apply_ = nullptr; - // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; - // Computation for condition and body of kWhile, only present for kWhile. - HloComputation* condition_ = nullptr; - HloComputation* body_ = nullptr; + // Computations called by this instruction. + std::vector called_computations_; - // Computation for select and scatter, only present for - // kSelectAndScatter. - HloComputation* select_ = nullptr; - HloComputation* scatter_ = nullptr; + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. + kBodyComputationIndex = 0, + kConditionComputationIndex = 1, + + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + }; // Outfeed configuration information, only present for kOutfeed. string outfeed_config_; @@ -799,14 +912,17 @@ class HloInstruction { std::vector operands_; // The users of this instruction. Users are HLOs where this instruction is an - // operand. - std::set users_; + // operand. The vector users_ and the set user_set_ contain identical + // members. The set enables fast membership testing and the vector enables + // fast, stable iteration. + std::vector users_; + std::unordered_set user_set_; // The set of control predecessors of this instruction. - std::set control_predecessors_; + std::vector control_predecessors_; // The set of control successors of this instruction. - std::set control_successors_; + std::vector control_successors_; // A trace instruction that consumes this instruction. // @@ -831,10 +947,15 @@ class HloInstruction { // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; + // Metadata for debugging. + OpMetadata metadata_; + TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; -string FusionKindString(HloInstruction::FusionKind kind); +string ToString(HloInstruction::FusionKind kind); + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 48711b605f2..bcf81cd8ddf 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -21,19 +21,22 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace { -#define EXPECT_ISET(A, E...) EXPECT_EQ(A, (std::set{E})) -#define EXPECT_IVEC(A, E...) EXPECT_EQ(A, (std::vector{E})) +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; -class HloInstructionTest : public ::testing::Test { +class HloInstructionTest : public HloTestBase { protected: HloInstructionTest() {} @@ -149,10 +152,10 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar"); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), bar.get()); - EXPECT_MATCH(add->operands(), testing::UnorderedMatcher( - foo.get(), bar.get())); - EXPECT_ISET(foo->users(), add.get()); - EXPECT_ISET(bar->users(), add.get()); + + EXPECT_THAT(add->operands(), UnorderedElementsAre(foo.get(), bar.get())); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add.get())); + EXPECT_THAT(bar->users(), UnorderedElementsAre(add.get())); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -385,12 +388,13 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { EXPECT_EQ(1, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - EXPECT_ISET(foo->users(), add_foobar.get()); - EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar.get())); + EXPECT_THAT(add_foobar->operands(), ElementsAre(foo.get(), bar.get())); - EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); - EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get()); - EXPECT_IVEC(add_foofoo->operands(), bar.get(), bar.get()); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), add_foofoo.get())); + EXPECT_THAT(add_foobar->operands(), ElementsAre(foo.get(), bar.get())); + EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar.get(), bar.get())); } TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { @@ -406,15 +410,17 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { foo.get(), bar.get()); EXPECT_EQ(2, foo->user_count()); - EXPECT_ISET(foo->users(), tuple.get(), add_foobar.get()); + EXPECT_THAT(foo->users(), + UnorderedElementsAre(tuple.get(), add_foobar.get())); // Replace the use of foo in tuple with bar. ASSERT_IS_OK(foo->ReplaceUseWith(tuple.get(), bar.get())); - EXPECT_ISET(foo->users(), add_foobar.get()); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar.get())); // Both uses of foo in tuple should have been replaced with bar. - EXPECT_IVEC(tuple->operands(), bar.get(), bar.get(), baz.get(), bar.get()); + EXPECT_THAT(tuple->operands(), + ElementsAre(bar.get(), bar.get(), baz.get(), bar.get())); } TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { @@ -427,7 +433,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get()); EXPECT_EQ(2, foo->user_count()); - EXPECT_ISET(foo->users(), exp.get(), log.get()); + EXPECT_THAT(foo->users(), UnorderedElementsAre(exp.get(), log.get())); EXPECT_EQ(0, bar->user_count()); // Replace the use of foo in exp with bar. 
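The tests drop the local `EXPECT_ISET`/`EXPECT_IVEC` macros in favor of gmock's `UnorderedElementsAre`/`ElementsAre`, which print element-level diffs on failure. A minimal example of the two matchers (assumes linking against `gtest_main` and gmock):

```cpp
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;

TEST(MatcherDemo, OrderedVsUnordered) {
  std::vector<int> users = {3, 1, 2};
  // ElementsAre checks order, which is now meaningful for the users_ vector.
  EXPECT_THAT(users, ElementsAre(3, 1, 2));
  // UnorderedElementsAre ignores order, like the old set-based EXPECT_ISET.
  EXPECT_THAT(users, UnorderedElementsAre(1, 2, 3));
}
```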
@@ -435,8 +441,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { // The use of foo in log should not have been affected. EXPECT_EQ(1, foo->user_count()); - EXPECT_ISET(foo->users(), log.get()); - EXPECT_IVEC(log->operands(), foo.get()); + EXPECT_THAT(foo->users(), UnorderedElementsAre(log.get())); + EXPECT_THAT(log->operands(), ElementsAre(foo.get())); // Bar should now be used in exp. EXPECT_EQ(1, bar->user_count()); @@ -467,7 +473,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get()); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), add_foofoo.get())); } TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { @@ -491,7 +498,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(3, bar->user_count()); - EXPECT_ISET(bar->users(), add_foobar.get(), exp.get(), tuple.get()); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), exp.get(), tuple.get())); } // Simple visitor that collects and post-processes each node in the graph. @@ -559,8 +567,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); - EXPECT_IVEC(fusion->operands(), constant.get()); - EXPECT_ISET(constant->users(), fusion.get(), exp.get()); + EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); + EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get())); } TEST_F(HloInstructionTest, BinaryFusionOp) { @@ -575,9 +583,12 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, add.get()); - EXPECT_IVEC(fusion->operands(), constant1.get(), constant2.get()); - EXPECT_ISET(constant1->users(), fusion.get(), add.get()); - EXPECT_ISET(constant2->users(), fusion.get(), add.get()); + EXPECT_THAT(fusion->operands(), + ElementsAre(constant1.get(), constant2.get())); + EXPECT_THAT(constant1->users(), + UnorderedElementsAre(fusion.get(), add.get())); + EXPECT_THAT(constant2->users(), + UnorderedElementsAre(fusion.get(), add.get())); } TEST_F(HloInstructionTest, ChainFusionOp) { @@ -594,8 +605,68 @@ TEST_F(HloInstructionTest, ChainFusionOp) { fusion->FuseInstruction(exp2.get()); fusion->FuseInstruction(exp1.get()); - EXPECT_IVEC(fusion->operands(), constant.get()); - EXPECT_ISET(constant->users(), fusion.get(), exp1.get()); + EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); + EXPECT_THAT(constant->users(), + UnorderedElementsAre(fusion.get(), exp1.get())); +} + +TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + // Create a chain of fused unary ops. 
+ auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)); + auto exp1 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + OpMetadata metadata; + metadata.set_op_name("tf_op"); + exp1->set_metadata(metadata); + exp2->set_metadata(metadata); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); + auto* fused = fusion->FuseInstruction(exp1.get()); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); +} + +TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { + // Create a fusion instruction containing a chain of map operations. + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + + auto make_map_computation = [&]() { + auto builder = HloComputation::Builder("FusionMap"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + return builder.Build(); + }; + + std::unique_ptr<HloComputation> computation_x = make_map_computation(); + std::unique_ptr<HloComputation> computation_y = make_map_computation(); + + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)); + auto map_1_x = + HloInstruction::CreateMap(scalar_shape, {constant.get()}, + computation_x.get(), /*static_operands=*/{}); + auto map_2_x = + HloInstruction::CreateMap(scalar_shape, {map_1_x.get()}, + computation_x.get(), /*static_operands=*/{}); + auto map_3_y = + HloInstruction::CreateMap(scalar_shape, {map_2_x.get()}, + computation_y.get(), /*static_operands=*/{}); + + auto fusion = HloInstruction::CreateFusion( + scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get()); + + EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get())); + + fusion->FuseInstruction(map_2_x.get()); + EXPECT_THAT(fusion->called_computations(), + ElementsAre(computation_y.get(), computation_x.get())); + + fusion->FuseInstruction(map_1_x.get()); + EXPECT_THAT(fusion->called_computations(), + ElementsAre(computation_y.get(), computation_x.get())); } TEST_F(HloInstructionTest, ComplexFusionOp) { @@ -636,8 +707,9 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Operands in the fusion instruction's operands() vector should be in the // order in which their users were fused. - EXPECT_IVEC(fusion->operands(), c1.get(), c3.get(), c2.get()); - EXPECT_ISET(c1->users(), add.get(), tuple.get(), fusion.get()); + EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get())); + EXPECT_THAT(c1->users(), + UnorderedElementsAre(add.get(), tuple.get(), fusion.get())); } // Convenience function for comparing two HloInstructions inside of @@ -890,5 +962,48 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } +TEST_F(HloInstructionTest, CloneSuffixNames) { + // Test that the suffix string added to cloned instructions is not + // duplicated. Rather, an incrementing numeric value should be appended. That + // is, we want "foo.clone2", not "foo.clone.clone". + + // Test cloning the same instruction multiple times. + auto foo = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); + EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); + + // Test custom suffixes. 
+ EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), + "%foo.bar2.clone"); + + // Test instruction name with a dot. + auto foo_baz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); + EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); + + // Test incrementing a large number after the suffix. + auto foo_clone234 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); + EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); + + // Test a non-numeric string after the cloning suffix. + auto foo_clonexyz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); + + // Test a name with multiple appearances of the suffix. + auto foo_clone_clone3 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); +} + } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc new file mode 100644 index 00000000000..e022c4836d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace testing { + +bool HloMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + // These cases are self-explanatory from the printed value. + if (!instruction || instruction->opcode() != opcode_) { + return false; + } + // Special case: no operand matchers means don't verify. + if (operands_.empty()) { + return true; + } + const auto& operands = instruction->operands(); + if (operands.size() != operands_.size()) { + *listener << "has too " << (operands.size() > operands_.size() ? "many" : "few") << " operands (got " << operands.size() << ", want " << operands_.size() << ")"; + return false; + } + for (int index = 0; index < operands.size(); index++) { + ::testing::StringMatchResultListener inner_listener; + if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\noperand " << index << ":\n\t" + << operands[index]->ToString() + << "\ndoesn't match expected:\n\t"; + operands_[index].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + } + return true; +} + +void HloMatcher::DescribeTo(::std::ostream* os) const { + *os << opcode_; + if (!operands_.empty()) { + *os << "("; + for (int i = 0; i < operands_.size(); i++) { + if (i > 0) { + *os << ", "; + } + operands_[i].DescribeTo(os); + } + *os << ")"; + } +} + +} // namespace testing +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h new file mode 100644 index 00000000000..141251011cc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace testing { + +class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> { + public: + HloMatcher(HloOpcode opcode, + std::vector<::testing::Matcher<const HloInstruction*>> operands) + : opcode_(opcode), operands_(operands) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + void DescribeTo(::std::ostream* os) const override; + + private: + HloOpcode opcode_; + std::vector<::testing::Matcher<const HloInstruction*>> operands_; +}; + +// HloInstruction* matchers for opcode and operands. Example: +// namespace op = xla::opcode_matchers; +// EXPECT_THAT(instruction, +// op::Add(op::Reshape(), op::Add(op::Reshape(), _))); +namespace opcode_matchers { +#define HLO_MATCHER(opcode) \ + template <typename... M> \ + ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \ + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( \ + ::xla::HloOpcode::k##opcode, {operands...})); \ + } +HLO_MATCHER(Abs); +HLO_MATCHER(Add); +HLO_MATCHER(Bitcast); +HLO_MATCHER(Broadcast); +HLO_MATCHER(Call); +HLO_MATCHER(Ceil); +HLO_MATCHER(Clamp); +HLO_MATCHER(Concatenate); +HLO_MATCHER(Constant); +HLO_MATCHER(Convert); +HLO_MATCHER(Convolution); +HLO_MATCHER(Copy); +HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CustomCall); +HLO_MATCHER(Divide); +HLO_MATCHER(Dot); +HLO_MATCHER(DynamicSlice); +HLO_MATCHER(DynamicUpdateSlice); +HLO_MATCHER(Eq); +HLO_MATCHER(Exp); +HLO_MATCHER(Floor); +HLO_MATCHER(Fusion); +HLO_MATCHER(Ge); +HLO_MATCHER(GetTupleElement); +HLO_MATCHER(Gt); +HLO_MATCHER(Index); +HLO_MATCHER(Infeed); +HLO_MATCHER(IsFinite); +HLO_MATCHER(Le); +HLO_MATCHER(Log); +HLO_MATCHER(LogicalAnd); +HLO_MATCHER(LogicalNot); +HLO_MATCHER(LogicalOr); +HLO_MATCHER(Lt); +HLO_MATCHER(Map); +HLO_MATCHER(Maximum); +HLO_MATCHER(Minimum); +HLO_MATCHER(Multiply); +HLO_MATCHER(Ne); +HLO_MATCHER(Negate); +HLO_MATCHER(Outfeed); +HLO_MATCHER(Pad); +HLO_MATCHER(Parameter); +HLO_MATCHER(Power); +HLO_MATCHER(Recv); +HLO_MATCHER(Reduce); +HLO_MATCHER(ReduceWindow); +HLO_MATCHER(Remainder); +HLO_MATCHER(Reshape); +HLO_MATCHER(Reverse); +HLO_MATCHER(Rng); +HLO_MATCHER(Select); +HLO_MATCHER(SelectAndScatter); +HLO_MATCHER(Send); +HLO_MATCHER(Sign); +HLO_MATCHER(Slice); +HLO_MATCHER(Sort); +HLO_MATCHER(Subtract); +HLO_MATCHER(Tanh); +HLO_MATCHER(Trace); +HLO_MATCHER(Transpose); +HLO_MATCHER(Tuple); +HLO_MATCHER(Update); +HLO_MATCHER(While); +#undef HLO_MATCHER +} // namespace opcode_matchers + +// Helper to convert smart to raw pointers for matching. +template <typename Container> +std::vector<const HloInstruction*> Pointers(const Container& container) { + std::vector<const HloInstruction*> result; + result.reserve(container.size()); + for (const auto& entry : container) result.push_back(entry.get()); + return result; +} + +} // namespace testing + +// Tell GMock to print HloInstruction* by value, so error messages are nice. +// Has to be in the same namespace as 'HloInstruction'. +void PrintTo(const HloInstruction* inst, ::std::ostream* os) { + *os << (inst ? inst->ToString() : "nullptr"); +} + +void PrintTo(HloInstruction* inst, ::std::ostream* os) { + PrintTo(const_cast<const HloInstruction*>(inst), os); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc new file mode 100644 index 00000000000..1465d1cacdc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; +using ::testing::Eq; + +namespace xla { +namespace { + +template +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(HloMatchersTest, Test) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto param = HloInstruction::CreateParameter(0, shape, "param"); + auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, + param.get(), param.get()); + auto add = HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param.get(), + mul.get()); + + EXPECT_THAT(add.get(), op::Add()); + EXPECT_THAT(add.get(), op::Add(op::Parameter(), op::Multiply())); + EXPECT_THAT(add.get(), + op::Add(op::Parameter(), op::Multiply(_, op::Parameter()))); + + // Negative matches: check the explanation string. + EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq("")); + EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())), + Eq("has too many operands (got 2, want 1)")); + EXPECT_THAT( + Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), + Eq("\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" + "doesn't match expected:\n\t" + "parameter")); + EXPECT_THAT( + Explain(add.get(), + op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))), + Eq("\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" + "doesn't match expected:\n\t" + "multiply(add, add), \n" + "operand 0:\n\t" + "%param = f32[1]{0} parameter(0)\n" + "doesn't match expected:\n\t" + "add")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 5d68b456cda..22ef9c590bc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -31,20 +32,46 @@ limitations under the License. 
namespace xla { -HloComputation* HloModule::AddEntryComputation( +HloModule::HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config) + : name_(name), + config_(config), + entry_computation_(nullptr), + has_entry_computation_handle_(true), + entry_computation_handle_(entry_computation_handle), + computation_name_uniquer_(/*separator=*/".") {} + +HloModule::HloModule(const string& name) + : name_(name), + entry_computation_(nullptr), + computation_name_uniquer_(/*separator=*/".") {} + +HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation) { - CHECK_EQ(nullptr, entry_computation_); - entry_computation_ = computation.get(); + computation->UniquifyName(&computation_name_uniquer_); computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); } +HloComputation* HloModule::AddEntryComputation( + std::unique_ptr computation) { + CHECK_EQ(nullptr, entry_computation_); + entry_computation_ = computation.get(); + + // If the module configuration has no entry layout computation set, create a + // default one based on the program shape. + if (!config_.has_entry_computation_layout()) { + config_.SetDefaultComputationLayout( + entry_computation_->ComputeProgramShape()); + } + return AddComputationInternal(std::move(computation)); +} + HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { - computation->set_parent(this); - computations_.push_back(std::move(computation)); - return computations_.back().get(); + return AddComputationInternal(std::move(computation)); } void HloModule::ReplaceComputations( @@ -123,6 +150,17 @@ string HloModule::ToString() const { return s.str(); } +HloModuleProto HloModule::ToProto() const { + HloModuleProto proto; + proto.set_name(name_); + proto.set_entry_computation_name(entry_computation_->name()); + for (const HloComputation* computation : MakeComputationPostOrder()) { + HloComputationProto computation_proto = computation->ToProto(); + proto.add_computations()->Swap(&computation_proto); + } + return proto; +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given @@ -232,7 +270,8 @@ std::list HloModule::MakeComputationPostOrder() const { std::set nonroot_computations; for (auto& computation : computations_) { for (auto& instruction : computation->instructions()) { - for (auto called_computation : instruction->MakeCalledComputationsSet()) { + for (HloComputation* called_computation : + instruction->called_computations()) { nonroot_computations.insert(called_computation); } } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d598750da65..4b14b4fd62a 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -23,8 +23,11 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -41,19 +44,15 @@ namespace xla { // computations are owned by the module. class HloModule { public: - explicit HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle) - : name_(name), - entry_computation_(nullptr), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config); // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (e.g., // tests). The versioned handle is used by the service in the compilation - // cache. - explicit HloModule(const string& name) - : name_(name), entry_computation_(nullptr) {} + // cache. A default configuration is created for this module. + explicit HloModule(const string& name); // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -82,6 +81,10 @@ class HloModule { return entry_computation_; } + ComputationLayout* mutable_entry_computation_layout() { + return config_.mutable_entry_computation_layout(); + } + const VersionedComputationHandle& entry_computation_handle() const { return entry_computation_handle_; } @@ -95,7 +98,10 @@ class HloModule { // computation B, then A will appear after B in the sort. std::list<HloComputation*> MakeComputationPostOrder() const; + const HloModuleConfig& config() const { return config_; } + string ToString() const; + HloModuleProto ToProto() const; // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. @@ -110,8 +116,17 @@ class HloModule { // Returns a randomly generated uint64. uint64 RandomNew64() const; + // Returns the unique name for a computation in this module. + string GetUniqueComputationName(const string& prefix) { + return computation_name_uniquer_.GetUniqueName(prefix); + } + private: + HloComputation* AddComputationInternal( + std::unique_ptr<HloComputation> computation); + const string name_; + HloModuleConfig config_; HloComputation* entry_computation_; std::vector<std::unique_ptr<HloComputation>> computations_; @@ -125,6 +140,9 @@ class HloModule { // Versioned handle of the entry computation of the module. bool has_entry_computation_handle_ = false; VersionedComputationHandle entry_computation_handle_; + + // Unique name generator for computation names, which are unique per module. + NameUniquer computation_name_uniquer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index c129ad1b392..a2235a26823 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -27,20 +28,27 @@ namespace xla { using tensorflow::strings::StrAppend; +HloModuleConfig::HloModuleConfig() {} + HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) : entry_computation_layout_(program_shape) {} +void HloModuleConfig::SetDefaultComputationLayout( + const ProgramShape& program_shape) { + entry_computation_layout_ = ComputationLayout(program_shape); +} + string HloModuleConfig::compilation_cache_key() const { string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_, "::hybrid=", has_hybrid_result_); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - entry_computation_layout_.parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - entry_computation_layout_.result_shape().SerializeAsString()); + entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; @@ -49,7 +57,7 @@ string HloModuleConfig::compilation_cache_key() const { if (replica_count() != 1) { StrAppend(&key, "::replica_count=", replica_count()); } - StrAppend(&key, "::fast_math_disabled=", fast_math_disabled_); + StrAppend(&key, debug_options_.DebugString()); return key; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index f9a61c1cd1c..ee32ab9bc4b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -32,14 +33,34 @@ namespace xla { // executable. class HloModuleConfig { public: + // A configuration can be created either with, or without an entry + // ComputationLayout. The default ctor creates it without -- in this case + // accessing entry_computation_layout will CHECK-fail. The ctor accepting a + // ProgramShape creates a computation layout using this shape. + HloModuleConfig(); explicit HloModuleConfig(const ProgramShape& program_shape); - // Return a reference to the layout of the entry computation. - const ComputationLayout& entry_computation_layout() const { - return entry_computation_layout_; + // Checks if this config has an entry computation layout already. + bool has_entry_computation_layout() const { + return entry_computation_layout_.has_value(); } + + // Sets the entry computation layout for this config. If the entry computation + // layout already exists, it is silently replaced. + void SetDefaultComputationLayout(const ProgramShape& program_shape); + + // Returns a constant reference to the layout of the entry computation. + // Assumes the layout was set. + const ComputationLayout& entry_computation_layout() const { + CHECK(entry_computation_layout_.has_value()); + return *entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the entry computation. Assumes + // the layout was set. 
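The config is now default-constructible, so the entry layout may be absent; the accessors above and the one that follows guard every dereference with a CHECK. A minimal sketch of the same guarded-optional pattern, with std::optional and assert standing in for tensorflow::gtl::optional and CHECK (toy names, not the real XLA types):

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Toy stand-in for ComputationLayout.
struct Layout {
  std::string shape;
};

class Config {
 public:
  bool has_layout() const { return layout_.has_value(); }
  void SetDefaultLayout(std::string shape) {
    layout_ = Layout{std::move(shape)};  // silently replaces any prior value
  }
  const Layout& layout() const {
    // Mirrors CHECK(entry_computation_layout_.has_value()).
    assert(layout_.has_value());
    return *layout_;
  }

 private:
  std::optional<Layout> layout_;  // empty until SetDefaultLayout is called
};
```

The effect of the design is that a missing layout is a programming error at the access site rather than a silently defaulted value.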
ComputationLayout* mutable_entry_computation_layout() { - return &entry_computation_layout_; + CHECK(entry_computation_layout_.has_value()); + return &(*entry_computation_layout_); } // Sets/returns whether to enable HLO-level profiling. @@ -60,23 +81,21 @@ class HloModuleConfig { } int64 replica_count() const { return replica_count_; } - // Sets/returns whether unsafe math optimizations are disabled for this - // module. Default is fast-math enabled. - // - // This is named fast_math_disabled rather than the more natural - // fast_math_enabled for consistency with the ExecutionOptions proto. - bool fast_math_disabled() const { return fast_math_disabled_; } - void set_fast_math_disabled(bool disabled) { fast_math_disabled_ = disabled; } - // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. string compilation_cache_key() const; + const DebugOptions& debug_options() const { return debug_options_; } + + void set_debug_options(const DebugOptions& debug_options) { + debug_options_ = debug_options; + } + private: // If you add new members, be sure to update compilation_cache_key. - ComputationLayout entry_computation_layout_; + tensorflow::gtl::optional entry_computation_layout_; // Whether to enable HLO-level profiling. bool hlo_profiling_enabled_ = false; @@ -97,7 +116,7 @@ class HloModuleConfig { // The number of replicas to compile this binary for. int64 replica_count_ = 1; - bool fast_math_disabled_ = false; + DebugOptions debug_options_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 0f4252522d3..870bc729aec 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -58,27 +58,32 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); - EXPECT_EQ(module->MakeComputationPostOrder().front(), computation); + EXPECT_THAT(module->MakeComputationPostOrder(), + ::testing::ElementsAre(computation)); } TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); - EXPECT_MATCH( - testing::ListToVec(module->MakeComputationPostOrder()), - testing::UnorderedMatcher(computation1, computation2)); + EXPECT_THAT(module->MakeComputationPostOrder(), + ::testing::UnorderedElementsAre(computation1, computation2)); + + // We specified the same name for both computations, but the HloModule should + // have made the names unique. 
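The uniqueness asserted just below comes from the NameUniquer member that HloModule now owns. A hypothetical simplification of its behavior, enough to show where "Constant" and "Constant.1" come from (the real xla::NameUniquer has more responsibilities, which this toy omits):

```cpp
#include <map>
#include <string>
#include <utility>

class ToyNameUniquer {
 public:
  explicit ToyNameUniquer(std::string separator)
      : separator_(std::move(separator)) {}

  // The first request for a prefix returns it unchanged; later requests get
  // a numeric suffix: "Constant", "Constant.1", "Constant.2", ...
  std::string GetUniqueName(const std::string& prefix) {
    int count = counts_[prefix]++;
    return count == 0 ? prefix : prefix + separator_ + std::to_string(count);
  }

 private:
  std::string separator_;
  std::map<std::string, int> counts_;
};
```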
+ EXPECT_EQ(computation1->name(), "Constant"); + EXPECT_EQ(computation2->name(), "Constant.1"); } TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -89,9 +94,9 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { CreateCallComputation({computation2, computation3})); auto post_order = module->MakeComputationPostOrder(); - EXPECT_MATCH(testing::ListToVec(post_order), - testing::UnorderedMatcher( - computation1, computation2, computation3, computation4)); + EXPECT_THAT(post_order, + ::testing::UnorderedElementsAre(computation1, computation2, + computation3, computation4)); EXPECT_EQ(post_order.back(), computation4); EXPECT_EQ(post_order.front(), computation1); } @@ -99,3 +104,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 5f7243b0fe7..ceb0cdaa316 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -74,6 +74,8 @@ string HloOpcodeString(HloOpcode opcode) { return "index"; case HloOpcode::kInfeed: return "infeed"; + case HloOpcode::kIsFinite: + return "is-finite"; case HloOpcode::kLe: return "less-than-or-equal-to"; case HloOpcode::kLog: @@ -163,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) { } } +bool HloOpcodeIsVariadic(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + return true; + default: + return false; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 5d60a77e14f..e2cdbfdfa7a 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -55,6 +55,7 @@ enum class HloOpcode { kGt, kIndex, kInfeed, + kIsFinite, kLe, kLog, kLogicalAnd, @@ -103,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { // Returns true iff the given opcode is a comparison operation. bool HloOpcodeIsComparison(HloOpcode opcode); +// Returns true iff the given opcode has variadic operands. +bool HloOpcodeIsVariadic(HloOpcode opcode); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 0b64c16fdc6..892c89f9df2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 38106dbbb11..72911ae9f91 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,15 +34,112 @@ namespace xla { -PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) - : module_(module) {} +namespace { -bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { +// Returns the nearest call graph ancestors of instructions 'a' and 'b' for +// which the ancestors are in the same computation. An instruction is a call +// graph ancestor of 'a' if the instruction calls the computation containing 'a' +// either directly or transitively. Degenerately, an instruction is an ancestor +// of itself. nullptr is returned if there is no common ancestor or if the +// caller chain of 'a' or 'b' diverges (has multiple callers) before the nearest +// common ancestor. +// +// Example: +// +// Entry computation: +// %x = Call(A, {Constant(42.0)}) +// %y = Call(B, {%x}) +// +// Computation A: +// %a = Negate(Param()) +// +// Computation B: +// %b = Exp(Param()); +// +// If called with %a and %b, this function would return (%x, %y). %x is an +// ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same +// computation. +std::pair<const HloInstruction*, const HloInstruction*> +GetNearestCallGraphAncestorsInSameComputation(const HloInstruction* a, + const HloInstruction* b, + const CallGraph& call_graph) { + // Lambda which returns the next instruction in the callee->caller chain in + // the call graph. This is the unique instruction which calls the computation + // containing 'instruction'. If more than one instruction calls the + // computation containing 'instruction' or no instructions call the + // computation then nullptr is returned. + auto next_caller = + [&call_graph]( + const HloInstruction* instruction) -> const HloInstruction* { + const CallGraphNode& node = call_graph.GetNode(instruction->parent()); + if (node.caller_callsites().size() != 1) { + return nullptr; + } + return node.caller_callsites()[0].instruction(); + }; + + // Iterate through the callee->caller chains and find the earliest common + // element. + for (const HloInstruction* a_ancestor = a; a_ancestor != nullptr; + a_ancestor = next_caller(a_ancestor)) { + for (const HloInstruction* b_ancestor = b; b_ancestor != nullptr; + b_ancestor = next_caller(b_ancestor)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + } + } + return {nullptr, nullptr}; +} + +} // namespace + +bool HloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // 'a' and 'b' may be in different computations. In this case, find the + // callgraph ancestor instructions which call (potentially transitively) the + // computations containing 'a' and 'b' and use these ancestor instructions to + // compare order. + const HloInstruction* a_ancestor; + const HloInstruction* b_ancestor; + std::tie(a_ancestor, b_ancestor) = + GetNearestCallGraphAncestorsInSameComputation(a, b, *call_graph_); + + if (a_ancestor == nullptr) { + // Ancestors in a common computation could not be found so consider the + // instructions 'a' and 'b' to be unordered. return false; } + // a_ancestor and b_ancestor must be either both null or both non-null. 
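Before the precondition checks resume, the chain walk implemented above can be restated compactly: lift each instruction through its unique caller until the two chains land in a common computation. A toy sketch with hypothetical stand-in types rather than the real CallGraph API:

```cpp
#include <utility>

struct ToyNode {
  const void* computation;  // stands in for HloComputation*
  const ToyNode* caller;    // unique caller, or nullptr if none or ambiguous
};

// Mirrors GetNearestCallGraphAncestorsInSameComputation: scan every pair of
// ancestors (outer chain for 'a', inner chain for 'b') and return the first
// pair sharing a computation; {nullptr, nullptr} if the chains never meet.
std::pair<const ToyNode*, const ToyNode*> NearestAncestors(const ToyNode* a,
                                                           const ToyNode* b) {
  for (const ToyNode* x = a; x != nullptr; x = x->caller) {
    for (const ToyNode* y = b; y != nullptr; y = y->caller) {
      if (x->computation == y->computation) return {x, y};
    }
  }
  return {nullptr, nullptr};
}
```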
+ CHECK_NE(b_ancestor, nullptr); + CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); +} + +HloOrderingProto HloOrdering::ToProto() const { + HloOrderingProto proto; + for (const auto& computation : module_->computations()) { + const std::vector* sequence = + SequentialOrder(*computation); + if (sequence != nullptr) { + HloOrderingProto::SequentialComputation* proto_computation = + proto.add_sequential_computations(); + proto_computation->set_computation_name(computation->name()); + for (const HloInstruction* instruction : *sequence) { + *proto_computation->add_instruction_names() = instruction->name(); + } + } + } + return proto; +} + +PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) + : HloOrdering(module) {} + +bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); + // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. return strict_predecessors_.at(b->parent())->IsReachable(b, a); } @@ -85,9 +183,9 @@ string DependencyHloOrdering::ToString() const { SequentialHloOrdering::SequentialHloOrdering( const HloModule* module, const HloModuleSequence& module_sequence) - : module_(module) { + : HloOrdering(module), module_sequence_(module_sequence) { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence) { + for (auto computation_order : module_sequence_) { const std::vector& order = computation_order.second; for (int i = 0; i < order.size(); ++i) { DCHECK_EQ(0, order_position_.count(order[i])); @@ -96,12 +194,9 @@ SequentialHloOrdering::SequentialHloOrdering( } } -bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { - return false; - } +bool SequentialHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { return false; @@ -109,6 +204,13 @@ bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, return order_position_.at(a) < order_position_.at(b); } +const std::vector* +SequentialHloOrdering::SequentialOrder( + const HloComputation& computation) const { + auto find_it = module_sequence_.find(&computation); + return find_it == module_sequence_.end() ? nullptr : &find_it->second; +} + string SequentialHloOrdering::ToString() const { std::vector pieces; pieces.push_back("SequentialHloOrdering"); @@ -136,6 +238,29 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } +StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. 
We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + namespace { // Class implementing a list scheduler of HLO instructions which produces a @@ -235,7 +360,7 @@ class ListScheduler { return freed_bytes; } - // Construct the scheduling priority of the given instruciton. + // Construct the scheduling priority of the given instruction. Priority GetPriority(const HloInstruction* instruction) { return {BytesFreedIfScheduled(instruction), instruction->user_count()}; } @@ -243,11 +368,24 @@ class ListScheduler { std::vector CreateSchedule() { std::vector schedule; - // Populate the ready list with instructions which have no operands. + // Populate the ready list with instructions which have no operands or + // control predecessors. + std::unordered_map unscheduled_pred_count; std::list ready_list; for (auto& instruction : computation_.instructions()) { - if (instruction->operand_count() == 0 && - instruction->control_predecessors().empty()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + for (auto& instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. + if (unscheduled_pred_count.count(instruction.get()) == 0) { ready_list.push_back(instruction.get()); } } @@ -279,28 +417,21 @@ class ListScheduler { } // Add new instructions to ready list. - // TODO(b/34466113): Replace this with successors()/predecessors() when - // predecessor/successor methods are added to HloInstruction. This also - // will resolve the nondeterminism of using a set here assuming - // predecessors/successors is a vector. - std::set successors = best->users(); - successors.insert(best->control_successors().begin(), - best->control_successors().end()); - for (auto* successor : successors) { - std::set predecessors(successor->operands().begin(), - successor->operands().end()); - predecessors.insert(successor->control_predecessors().begin(), - successor->control_predecessors().end()); - bool is_ready = true; - for (auto* predecessor : predecessors) { - if (scheduled_instructions_.count(predecessor) == 0) { - is_ready = false; - break; - } - } - if (is_ready) { - ready_list.push_back(successor); + auto update_pred_count = [&unscheduled_pred_count, + &ready_list](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + ready_list.push_back(inst); } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. 
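The update_pred_count lambda above is the core of a standard predecessor-counting (Kahn-style) ready list, and the loops that follow feed it each scheduled instruction's users and control successors. A generic, self-contained sketch of the same pattern, using FIFO order where the real scheduler applies its bytes-freed priority:

```cpp
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

// Returns a topological order of nodes 0..n-1 given directed edges.
std::vector<int> ReadyListSchedule(
    int n, const std::vector<std::pair<int, int>>& edges) {
  std::unordered_map<int, int> pred_count;
  std::unordered_map<int, std::vector<int>> successors;
  for (const auto& edge : edges) {
    ++pred_count[edge.second];
    successors[edge.first].push_back(edge.second);
  }
  std::queue<int> ready;
  for (int v = 0; v < n; ++v) {
    if (pred_count[v] == 0) ready.push(v);  // no unscheduled predecessors
  }
  std::vector<int> order;
  while (!ready.empty()) {
    int v = ready.front();
    ready.pop();
    order.push_back(v);
    for (int s : successors[v]) {
      if (--pred_count[s] == 0) ready.push(s);  // successor became ready
    }
  }
  return order;  // order.size() == n iff the graph was acyclic
}
```

The real ListScheduler selects the best ready candidate via GetPriority rather than taking the queue front; only the readiness bookkeeping is shown here.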
+ for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); } } CHECK_EQ(schedule.size(), computation_.instructions().size()); @@ -327,6 +458,113 @@ class ListScheduler { std::unordered_set scheduled_instructions_; }; +int64 SumLogicalBufferSizes(const std::vector& buffers, + const LogicalBuffer::SizeFunction& size_function) { + int64 size = 0; + for (const LogicalBuffer* buffer : buffers) { + size += size_function(*buffer); + } + return size; +} + +StatusOr> RunDFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // This ordering is based on DFS post-order, with a heuristic to decide which + // operand to visit first. The heuristic is based on 'extra_users', which is + // simply users-1 for each instruction. By subtracting 1, we're saying that + // instructions with no users or a single user don't count; instructions with + // lots of fan-out will be visited earlier. + tensorflow::gtl::FlatMap extra_users; + tensorflow::gtl::FlatMap total_sizes; + for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; + total_sizes[hlo] = SumLogicalBufferSizes( + points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + tensorflow::gtl::FlatSet unique_operands( + hlo->operands().begin(), hlo->operands().end()); + for (const HloInstruction* operand : unique_operands) { + extra_users[hlo] += extra_users[operand]; + total_sizes[hlo] += total_sizes[operand]; + } + } + CHECK_EQ(extra_users.size(), computation.instructions().size()); + CHECK_EQ(total_sizes.size(), computation.instructions().size()); + + // Construct a total order based on DFS post-order, visiting operands in + // decreasing cumulative extra user order, and next by cumulative size, with a + // tiebreaker by name for determinism. + std::vector sequence; + FunctionVisitor visitor([&sequence](HloInstruction* hlo) { + sequence.push_back(hlo); + return Status::OK(); + }); + TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); + })); + CHECK_EQ(sequence.size(), computation.instructions().size()); + return sequence; +} + +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // We try both a list-scheduler based ordering and a DFS based ordering, and + // choose whichever returns a lower min-memory, not accounting for + // fragmentation. + // + // Note that this is just a heuristic. 
One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. + TF_ASSIGN_OR_RETURN( + std::vector list_sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 list_memory, + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + + TF_ASSIGN_OR_RETURN( + std::vector dfs_sequence, + RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 dfs_memory, + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); + VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + + if (list_memory <= dfs_memory) { + VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + return list_sequence; + } else { + VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + return dfs_sequence; + } +} + } // namespace StatusOr @@ -335,16 +573,23 @@ CreateMemoryMinimizingSequence( SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - - for (auto& computation : module.computations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation.get()], - ListScheduler::Run(*computation, *points_to_analysis, size_function)); + for (const auto& computation : module.computations()) { + TF_ASSIGN_OR_RETURN(sequence[computation.get()], + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function)); } - return sequence; } +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(computation.parent())); + return CreateMemoryMinimizingSequence(computation, *points_to_analysis, + size_function); +} + std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 97f7c6060b8..b59e1ea5eb0 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -31,19 +33,43 @@ limitations under the License. namespace xla { -// Abstract base class for describing a partial ordering of HLO -// instructions. Used to determine live range overlap of HLO instruction output -// buffers. +// Base class for describing a partial ordering of HLO instructions. Used to +// determine live range overlap of HLO instruction output buffers. class HloOrdering { public: - HloOrdering() = default; + HloOrdering(const HloModule* module) + : module_(module), call_graph_(CallGraph::Build(module)) {} virtual ~HloOrdering() = default; // Returns true if instruction 'a' executes before instruction 'b'. This is // not reflexive, that is, an instruction does not execute before itself. 
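Looking back at CreateMemoryMinimizingSequence in hlo_ordering.cc above, the list-versus-DFS decision is a plain arg-min over one cost model. A hedged sketch of that selection step, with hypothetical stand-ins for the two schedulers and the heap-simulator cost:

```cpp
#include <utility>

// Keep whichever candidate the cost model says needs less memory; ties go to
// the list-based schedule, as in CreateMemoryMinimizingSequence above.
template <typename Sequence, typename CostFn>
Sequence PickCheaperSchedule(Sequence list_based, Sequence dfs_based,
                             CostFn cost) {
  return cost(list_based) <= cost(dfs_based) ? std::move(list_based)
                                             : std::move(dfs_based);
}
```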
- virtual bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const = 0; + bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; + + // Returns the sequential instruction order for the given computation, or + // nullptr if the computation does not have a sequential ordering. + virtual const std::vector<const HloInstruction*>* SequentialOrder( + const HloComputation& computation) const = 0; + virtual string ToString() const = 0; + + // Returns the serialized representation of this ordering. + // Only sequential computation orders are represented. + HloOrderingProto ToProto() const; + + protected: + // Returns true if instruction 'a' executes before instruction 'b'. + // Precondition: 'a' and 'b' are in the same computation. + // + // Derived classes should implement this method to determine the order of + // instructions in the same computation. ExecutesBefore() analyzes the + // callgraph and uses this method to determine ordering of instructions in + // different computations. + virtual bool ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const = 0; + + const HloModule* module_; + + std::unique_ptr<CallGraph> call_graph_; }; // Base class for partial orderings implemented by a map of strict predecessors @@ -52,20 +78,23 @@ class PredecessorHloOrdering : public HloOrdering { public: ~PredecessorHloOrdering() override = default; - // Returns true if instruction 'a' executes before instruction 'b'. - // Instructions in different computations are not ordered. - bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const override; + // Returns nullptr indicating the computation does not have a sequential + // ordering. + const std::vector<const HloInstruction*>* SequentialOrder( + const HloComputation& computation) const override { + return nullptr; + } protected: explicit PredecessorHloOrdering(const HloModule* module); string ToStringHelper(const string& name) const; - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; - // For each each computation in the module, this is the set of the - // instruction's strict predecessors. An instruction is not an element of its - // own strict predecessor set. + // For each computation in the module, this is the set of the instruction's + // strict predecessors. An instruction is not an element of its own strict + // predecessor set. + // + // Subclasses should fill this in to define the desired ordering. tensorflow::gtl::FlatMap* SequentialOrder( + const HloComputation& computation) const override; + string ToString() const override; protected: - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; + + const HloModuleSequence module_sequence_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -156,6 +187,16 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; }; +std::ostream& operator<<( + std::ostream& out, + const SequentialHloOrdering::HloModuleSequence& module_sequence); + +// Returns the minimum memory required to compute the given module sequence, +// assuming no fragmentation. 
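A hypothetical caller of the declaration that follows, patterned on the MinimumMemoryForSequenceTest later in this change. This fragment assumes the surrounding XLA types and is not a standalone program; the size function models 8-byte pointers inside tuples:

```cpp
// Sketch only: cost a hand-built per-computation schedule.
auto size_fn = [](const LogicalBuffer& buffer) {
  return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
SequentialHloOrdering::HloModuleSequence module_sequence;
// ... fill module_sequence[computation] with instruction orders ...
int64 min_bytes =
    MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie();
```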
+StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. @@ -163,9 +204,10 @@ StatusOr CreateMemoryMinimizingSequence( const HloModule& module, const LogicalBuffer::SizeFunction& size_function); -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); +// Overload of above that computes the sequence for a single computation. +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 425bee601a8..21d852a51d6 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -58,26 +58,166 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); TF_ASSIGN_OR_ASSERT_OK( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(module, [](const LogicalBuffer& buffer) { + CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. - EXPECT_EQ(module.entry_computation()->instruction_count(), - sequence.at(module.entry_computation()).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module.entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module.entry_computation()).back()); + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - SequentialHloOrdering ordering(&module, sequence); + SequentialHloOrdering ordering(module.get(), sequence); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } +TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { + // Tests the ordering of instructions in different computations using the + // following HLO code: + // + // Entry computation: + // %x = Call(A, {}) + // %y = Call(B, {%x}) + // + // Computation A: + // %a = Call(C, {}) + // + // Computation B: + // %b = Call(C, {}) + // + // Computation C: + // %c = Constant(42.0f) + // + // This results in a diamond-shaped callgraph. 
+ auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder_c = HloComputation::Builder("C"); + HloInstruction* c = builder_c.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloComputation* computation_c = + module->AddEmbeddedComputation(builder_c.Build()); + + auto builder_b = HloComputation::Builder("B"); + builder_b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* b = builder_b.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_b = + module->AddEmbeddedComputation(builder_b.Build()); + + auto builder_a = HloComputation::Builder("A"); + HloInstruction* a = builder_a.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_a = + module->AddEmbeddedComputation(builder_a.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_a)); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {x}, computation_b)); + module->AddEntryComputation(builder.Build()); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.ExecutesBefore(x, y)); + EXPECT_FALSE(ordering.ExecutesBefore(y, x)); + + EXPECT_TRUE(ordering.ExecutesBefore(a, b)); + EXPECT_FALSE(ordering.ExecutesBefore(b, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(a, x)); + EXPECT_TRUE(ordering.ExecutesBefore(a, y)); + EXPECT_FALSE(ordering.ExecutesBefore(x, a)); + EXPECT_FALSE(ordering.ExecutesBefore(y, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(b, x)); + EXPECT_FALSE(ordering.ExecutesBefore(b, y)); + EXPECT_TRUE(ordering.ExecutesBefore(x, b)); + EXPECT_FALSE(ordering.ExecutesBefore(y, b)); + + // Instruction 'c' is called from multiple callsites and should be unordered + // relative to all other instructions in the module. 
+ EXPECT_FALSE(ordering.ExecutesBefore(c, a)); + EXPECT_FALSE(ordering.ExecutesBefore(c, b)); + EXPECT_FALSE(ordering.ExecutesBefore(c, x)); + EXPECT_FALSE(ordering.ExecutesBefore(c, y)); + EXPECT_FALSE(ordering.ExecutesBefore(a, c)); + EXPECT_FALSE(ordering.ExecutesBefore(b, c)); + EXPECT_FALSE(ordering.ExecutesBefore(x, c)); + EXPECT_FALSE(ordering.ExecutesBefore(y, c)); +} + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 91468fd35b0..119e2d79022 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ 
b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -26,6 +25,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::StrAppend; + namespace xla { namespace { @@ -38,32 +39,52 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, } // namespace StatusOr<bool> HloPassPipeline::Run(HloModule* module) { - legacy_flags::HloPassPipelineFlags* flags = - legacy_flags::GetHloPassPipelineFlags(); - std::vector<string> tmp = - tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); - tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end()); + run_called_ = true; + + VLOG(1) << "Running HLO pass pipeline " << name(); + + auto repeated_field = + module->config().debug_options().xla_disable_hlo_passes(); + tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(), + repeated_field.end()); + if (!disabled_passes.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << tensorflow::str_util::Join(disabled_passes, ", "); + } + + auto run_invariant_checkers = [this, module]() -> Status { + for (auto& invariant_checker : invariant_checkers_) { + TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module)); + TF_RET_CHECK(!changed) << "invariant checkers must not change the graph"; + } + return Status::OK(); + }; string prefix = name().ToString() + ": pipeline start"; bool changed = false; string message; for (auto& pass : passes_) { - if (!disabled_passes.empty() && - disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(pass->name().ToString()) > 0) { + VLOG(1) << " Skipping HLO pass " << pass->name() + << ", disabled by --xla_disable_hlo_passes"; continue; } + VLOG(1) << " HLO pass " << pass->name(); + // Emit label containing: "after foo-pass, before bar-pass". message.clear(); - tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name()); + StrAppend(&message, prefix, ", before ", pass->name()); DumpModule(dumper_, *module, message); + TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); changed |= changed_this_pass; prefix.clear(); - tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name()); + StrAppend(&prefix, name(), ": after ", pass->name()); } + TF_RETURN_IF_ERROR(run_invariant_checkers()); DumpModule(dumper_, *module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 7a9c606a487..682c4b952df 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -47,11 +47,23 @@ class HloPassPipeline : public HloPassInterface { // Returns a reference to the added pass. template <typename T, typename... Args> T& AddPass(Args&&... args) { + CHECK(!run_called_) << "AddPass cannot be called after Run"; auto pass = new T(std::forward<Args>(args)...); passes_.push_back(std::unique_ptr<T>(pass)); return *pass; } + // Add an invariant-checking pass to the pipeline. It will be run before and + // after each HLO pass.
The invariant checking pass must not mutate the graph + // (it is required to always return "false" from its Run() method). + template <typename T, typename... Args> + T& AddInvariantChecker(Args&&... args) { + CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; + auto pass = new T(std::forward<Args>(args)...); + invariant_checkers_.push_back(std::unique_ptr<T>(pass)); + return *pass; + } + // Run all passes on the given HLO module. StatusOr<bool> Run(HloModule* module) override; @@ -59,6 +71,8 @@ class HloPassPipeline : public HloPassInterface { const string name_; Compiler::HloDumper dumper_; std::vector<std::unique_ptr<HloPassInterface>> passes_; + std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_; + bool run_called_ = false; TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc new file mode 100644 index 00000000000..727ad0178c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" + +namespace xla { + +HloProto MakeHloProto(const HloModule& module, + const BufferAssignment& assignment) { + HloModuleProto proto_module = module.ToProto(); + HloOrderingProto proto_ordering = + assignment.liveness().hlo_ordering().ToProto(); + BufferAssignmentProto proto_assignment = assignment.ToProto(); + HloProto proto; + proto.mutable_hlo_module()->Swap(&proto_module); + proto.mutable_hlo_ordering()->Swap(&proto_ordering); + proto.mutable_buffer_assignment()->Swap(&proto_assignment); + return proto; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h new file mode 100644 index 00000000000..603259a11fc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities to manipulate data in hlo.proto.
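A hedged usage sketch for the MakeHloProto helper defined above: a backend holding a module and a completed BufferAssignment could serialize the combined state to disk. The wrapper name, the file path, and the choice of tensorflow::WriteBinaryProto are illustrative assumptions, not part of this change:

```c++
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/core/platform/env.h"

// Hypothetical helper: snapshot the HLO module, ordering, and buffer
// assignment into a single binary proto file.
tensorflow::Status DumpHloSnapshot(const xla::HloModule& module,
                                   const xla::BufferAssignment& assignment,
                                   const string& path) {
  xla::HloProto proto = xla::MakeHloProto(module, assignment);
  return tensorflow::WriteBinaryProto(tensorflow::Env::Default(), path, proto);
}
```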
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +// Returns a serialized representation of the HLO state. +HloProto MakeHloProto(const HloModule& module, + const BufferAssignment& assignment); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 1556d1772f9..a153d73dbd8 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -32,6 +32,16 @@ bool IsConstantR0F32(HloInstruction* instruction, float* out) { return false; } +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kParameter && + operand->opcode() != HloOpcode::kConstant) { + return false; + } + } + return true; +} + bool AllOperandsAreParameters(const HloInstruction& instruction) { for (const auto& operand : instruction.operands()) { if (operand->opcode() != HloOpcode::kParameter) { @@ -41,6 +51,15 @@ bool AllOperandsAreParameters(const HloInstruction& instruction) { return true; } +bool AllOperandsAreConstants(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kConstant) { + return false; + } + } + return true; +} + HloInstruction* GetMatchingOperand( std::function<bool(const HloInstruction*)> matcher, HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index 864f892e920..c79347bbf9d 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -28,9 +28,16 @@ namespace hlo_query { // Precondition: out != nullptr bool IsConstantR0F32(HloInstruction* instruction, float* out); +// Returns whether each of an instruction's operands is either a constant or a +// parameter. +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction); + // Returns whether all of an instruction's operands are parameters. bool AllOperandsAreParameters(const HloInstruction& instruction); +// Returns whether all of an instruction's operands are constants. +bool AllOperandsAreConstants(const HloInstruction& instruction); + // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc new file mode 100644 index 00000000000..2c1b0fff4e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -0,0 +1,1294 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_rematerialization.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +using ::tensorflow::strings::HumanReadableNumBytes; + +namespace xla { + +namespace { + +// Returns true if the given instruction is rematerializable. +bool IsRematerializable(const HloInstruction* instruction) { + // Conservatively, don't rematerialize instructions with control + // dependencies. For one, control dependencies are added to prevent + // interference of aliased buffers (say, in while bodies) and + // rematerialization is ignorant of liveness and may break the intended + // ordering. + if (!instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + return false; + } + + // Don't rematerialize instructions with side effects, those with a cost that + // might not be captured by HloCostAnalysis, or instructions which cannot be + // cloned safely. + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kOutfeed: + case HloOpcode::kInfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kSend: + case HloOpcode::kTrace: + case HloOpcode::kWhile: + return false; + default: + return true; + } +} + +// Class which maintains an ordered list of instructions with fast insertion +// before arbitrary elements. +class InstructionList { + public: + explicit InstructionList(const std::vector<const HloInstruction*> order) { + int64 position = 0; + for (const HloInstruction* inst : order) { + instructions_.push_back(const_cast<HloInstruction*>(inst)); + instruction_iterators_.insert({const_cast<HloInstruction*>(inst), + std::next(instructions_.end(), -1)}); + // Initially position numbers are uniquely assigned in order. Later as + // instructions are added with InsertBefore* methods, some instructions + // may have duplicate position numbers, but the values are guaranteed + // to remain monotonically increasing through the list, so the numbering + // is still useful for quickly(-ish) determining the order of arbitrary + // instructions in the list. For example, inserting x before c in a list + // with positions {a: 0, b: 1, c: 2} gives x position 2 as well. + position_number_[inst] = position; + first_at_position_[position] = inst; + position++; + } + } + + // Returns the list of instructions.
+ const std::list<HloInstruction*>& instructions() const { + return instructions_; + } + + // Insert instruction 'to_insert' immediately before instruction 'before' in + // the list. + void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + VLOG(3) << "InsertBefore: " << to_insert->name() << " before " + << before->name(); + auto it = instruction_iterators_.find(before); + CHECK(it != instruction_iterators_.end()); + instruction_iterators_.insert( + {to_insert, instructions_.insert(it->second, to_insert)}); + // Assign the same position number to the newly added instruction as + // 'before'. This guarantees monotonicity of the position numbers, but not + // uniqueness. + int64 pos = position_number_.at(before); + position_number_[to_insert] = pos; + if (first_at_position_.at(pos) == before) { + first_at_position_[pos] = to_insert; + } + } + + // Insert instruction 'to_insert' immediately before the earliest instruction + // in 'before_instructions'. + void InsertBeforeInstructions( + HloInstruction* to_insert, + tensorflow::gtl::ArraySlice<HloInstruction*> before_instructions) { + VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" + << tensorflow::str_util::Join( + before_instructions, ", ", + [](string* out, HloInstruction* inst) { + tensorflow::strings::StrAppend(out, inst->name()); + }) + << "}"; + + // Find the minimal position number of any instruction in + // 'before_instructions'. + CHECK(!before_instructions.empty()); + int64 min_position_number = std::numeric_limits<int64>::max(); + for (const HloInstruction* instruction : before_instructions) { + min_position_number = + std::min(min_position_number, position_number_.at(instruction)); + } + + // Because more than one instruction in 'before_instructions' may have a + // position number of 'min_position_number', find the first such instruction + // with position number 'min_position_number'. + for (auto it = instruction_iterators_.at( + first_at_position_.at(min_position_number)); + it != instructions_.end() && + position_number_.at(*it) == min_position_number; + ++it) { + if (std::find(before_instructions.begin(), before_instructions.end(), + *it) != before_instructions.end()) { + return InsertBefore(to_insert, *it); + } + } + LOG(FATAL) << "Expected to find instruction in before_instructions with " + "position number " + << min_position_number; + } + + private: + // List of instructions. + std::list<HloInstruction*> instructions_; + + // Iterators for each instruction in the list. + tensorflow::gtl::FlatMap<const HloInstruction*, std::list<HloInstruction*>::iterator> + instruction_iterators_; + + // A number assigned to each instruction which increases monotonically through + // 'instructions_'. Used to facilitate fast insertion of an instruction before + // the earliest instruction in a set of instructions + // (InsertBeforeInstructions) by enabling fast-ish ordering queries between + // instructions. If position_number_[a] < position_number_[b] then 'a' comes + // before 'b' in the list. If the position numbers are the same then nothing + // can be said about their order without examining the list. + // + // On object construction this value is precisely the instruction's ordinal + // position in the list. Instructions inserted via InsertBefore receive + // duplicate values. However, monotonicity is preserved. + tensorflow::gtl::FlatMap<const HloInstruction*, int64> position_number_; + + // The first instruction in the list assigned a particular position number. + tensorflow::gtl::FlatMap<int64, const HloInstruction*> first_at_position_; +}; + +// Return the HloInstructions which use the given LogicalBuffer.
Sets +// has_indirect_users to whether any of the uses is indirect. A use is indirect +// if the instruction defining logical_buffer is not an operand of the use. This +// can happen via buffer aliasing (eg, tuples). +std::vector<const HloInstruction*> GetUsers( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { + std::vector<const HloInstruction*> users; + // To identify uses iterate through all HloInstruction users of the + // BufferAliases of the logical buffer. + *has_indirect_users = false; + for (const BufferAlias& buffer_alias : + points_to_analysis.GetBufferAliases(*logical_buffer)) { + for (const HloInstruction* user : buffer_alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(buffer_alias.instruction(), + buffer_alias.index(), user, + points_to_analysis)) { + // The alias may be an operand of 'user', but the LogicalBuffer cannot + // possibly be used by the instruction so ignore 'user'. This is the + // case, for example, for the tuple element buffers in a GetTupleElement + // instruction (the GTE instruction only uses the pointer vector). + continue; + } + if (buffer_alias.instruction() != logical_buffer->instruction()) { + *has_indirect_users = true; + } + // A buffer may be used by the instruction via more than one alias. For + // example, a buffer which appears in more than one element of a tuple. + if (std::find(users.begin(), users.end(), user) == users.end()) { + users.push_back(user); + } + } + } + return users; +} + +// Class for tracking memory usage of a computation as the instructions are +// placed sequentially. Memory usage is the sum of the sizes of live values +// (LogicalBuffers) at the current point in the instruction sequence. +class MemoryUsageTracker { + public: + MemoryUsageTracker( + const HloComputation* computation, + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); + + // Starts the placement of the given instruction. This adds the sizes of the + // LogicalBuffers defined by the instruction to the current memory + // usage. Placement is broken into two steps (BeginInstruction and + // EndInstruction) to accurately model memory usage. At BeginInstruction the + // memory for the output value(s) of the current instruction is allocated. At + // EndInstruction memory for dead operand(s) is freed. + Status BeginInstruction(const HloInstruction* instruction); + + // Finishes the placement of the current instruction. This frees any dead + // operands or dead result of the instruction. This must be called after + // each call to BeginInstruction. + Status EndInstruction(); + + // Returns the number of bytes that the current memory usage will be reduced + // if the given instruction is rematerialized. + int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; + + // Adjusts memory usage to account for the rematerialization of + // original_instruction for all remaining unplaced uses. The rematerialization + // is remat_instruction. This method should be called after the HLO graph has + // been transformed (rematerialization instruction created and connected to + // uses). + Status AddRematerializedInstruction(HloInstruction* original_instruction, + HloInstruction* remat_instruction); + + // Returns whether the given instruction has been placed (BeginInstruction + // has been called with 'instruction' as the argument).
+ bool IsPlaced(const HloInstruction* instruction) const { + return ContainsKey(placed_instructions_, instruction); + } + + // Returns the current memory usage. This is the sum of sizes of all live + // values. + int64 memory_usage() const { return memory_usage_; } + + // Returns the current instruction being placed. + const HloInstruction* in_progress_instruction() const { + return in_progress_instruction_; + } + + // Check invariants of the data structure. This is expensive to call. + bool Check() const; + + string ToString() const; + + private: + // Type holding a unique identifier for each Buffer object. + using BufferId = int64; + + // A Buffer represents a single LogicalBuffer in the computation including + // various metadata useful for tracking liveness of the value. A LogicalBuffer + // is not used directly because the HLO graph is transformed and + // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after + // HLO graph transformations. + struct Buffer { + // The unique id of this Buffer. This value is equal to the buffer's index + // in the vector buffers_. + const BufferId id; + + // The instruction which defines this buffer. + const HloInstruction* defining_instruction; + + // The materialized size of the buffer in bytes. + const int64 size; + + // Whether this buffer is live-out of the computation. + bool live_out; + + // Whether this buffer has indirect uses. Ie, an instruction which is not a + // user of defining_instruction uses this buffer. This can occur due to + // buffer aliasing (eg, tuples). + bool has_indirect_uses; + + // The instructions which use this buffer. + std::vector<const HloInstruction*> users; + + // The number of users (HloInstructions) of this buffer which have not yet + // been placed in the sequence. + int64 unfinished_user_count; + + string ToString() const { + return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", + defining_instruction->name(), + ", size ", size, " bytes)"); + } + }; + + // Creates a Buffer representing the given logical buffer. The buffer is added + // to buffers_ and a reference is returned. + Buffer& CreateBufferFromLogicalBuffer( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, + const HloRematerialization::ShapeSizeFunction& size_function, + bool live_out) { + bool has_indirect_uses = false; + std::vector<const HloInstruction*> users = + GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); + return NewBuffer(logical_buffer->instruction(), + size_function(logical_buffer->shape()), std::move(users), + live_out, has_indirect_uses); + } + + // Create a new buffer representing a rematerialization of given buffer for + // the given uses. + Buffer& RematerializeBuffer( + const Buffer& original_buffer, const HloInstruction* remat_instruction, + std::vector<const HloInstruction*>&& rematerialized_uses) { + CHECK(IsPlaced(original_buffer.defining_instruction)); + CHECK(!original_buffer.has_indirect_uses); + CHECK(!original_buffer.live_out); + for (const HloInstruction* use : rematerialized_uses) { + CHECK(!IsPlaced(use)); + } + return NewBuffer(remat_instruction, original_buffer.size, + std::move(rematerialized_uses), /*live_out=*/false, + /*has_indirect_uses=*/false); + } + + // Return number of bytes allocated for the buffer with the given id. Buffers + // allocated by the calling computation (eg, parameter and output buffers) are + // considered to have zero bytes because the memory is accounted for in a + // different computation.
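To make the caller-side accounting described above concrete, an invented example (not from the source):

```c++
// Invented example: inside a computation F(p: f32[256]) -> f32[256] with
//   t = exp(p)      // F-internal temporary: AllocatedSize(t) == 1024 bytes
//   r = add(t, p)   // live-out result:      AllocatedSize(r) == 0
// the parameter p also reports AllocatedSize(p) == 0. The zero-byte entries
// are charged to F's caller, which is what the comment above describes.
```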
+ int64 AllocatedSize(BufferId buffer_id) const { + const Buffer& buffer = buffers_.at(buffer_id); + HloOpcode def_opcode = buffer.defining_instruction->opcode(); + if (buffer.live_out || def_opcode == HloOpcode::kParameter) { + return 0; + } else { + return buffer.size; + } + } + + // Returns true if BeginInstruction and EndInstruction have been called for the + // given instruction. + bool IsFinished(const HloInstruction* instruction) const { + return IsPlaced(instruction) && instruction != in_progress_instruction_; + } + + // Returns whether the given buffer is being used by the in-progress + // instruction. + bool IsInUse(BufferId buffer_id) const { + if (in_progress_instruction_ == nullptr) { + return false; + } + const std::vector<BufferId>& in_progress_uses = + buffers_used_by_instruction_.at(in_progress_instruction_); + return std::find(in_progress_uses.begin(), in_progress_uses.end(), + buffer_id) != in_progress_uses.end(); + } + + // Returns whether the given instruction is live at the current program + // point. + bool IsCurrentlyLive(BufferId buffer_id) const { + const Buffer& buffer = buffers_[buffer_id]; + return (IsPlaced(buffer.defining_instruction) && + buffer.unfinished_user_count > 0); + } + + // Create a new buffer, add it to buffers_, and return a reference. + Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, + std::vector<const HloInstruction*>&& users, bool live_out, + bool has_indirect_uses) { + int buffer_id = buffers_.size(); + buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, + has_indirect_uses, users, + static_cast<int64>(users.size())}); + return buffers_.back(); + } + + const HloComputation* computation_; + + // Instruction list containing the ordering of instructions in + // computation_. This is the order in which instructions are placed + // (BeginInstruction/EndInstruction calls). + const InstructionList& instruction_list_; + + // Memory usage at the currently placed instruction. + int64 memory_usage_ = 0; + + // The instruction currently being placed. This value is non-null only + // between the calling of BeginInstruction and EndInstruction. + const HloInstruction* in_progress_instruction_ = nullptr; + + // The buffers defined by each instruction. + std::unordered_map<const HloInstruction*, std::vector<BufferId>> + buffers_defined_by_instruction_; + + // The buffers used by each instruction. + std::unordered_map<const HloInstruction*, std::vector<BufferId>> + buffers_used_by_instruction_; + + // The set of instructions which have been placed. That is, BeginInstruction + // has been called with the instruction as an argument. + tensorflow::gtl::FlatSet<const HloInstruction*> placed_instructions_; + + // All buffers in the computation. + std::vector<Buffer> buffers_; +}; + +MemoryUsageTracker::MemoryUsageTracker( + const HloComputation* computation, + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list) + : computation_(computation), instruction_list_(instruction_list) { + // Iterate through all LogicalBuffers in the computation and gather the + // instructions which define them in buffers_defined_by_instruction_ and the + // instructions which use them in buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + // Initialize empty vectors for defs and uses of each instruction.
+ buffers_used_by_instruction_[instruction.get()]; + buffers_defined_by_instruction_[instruction.get()]; + } + + tensorflow::gtl::FlatSet<const LogicalBuffer*> live_out_set = + points_to_analysis.GetPointsToSet(computation_->root_instruction()) + .CreateFlattenedSet(); + tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId> + logical_buffer_to_buffer_id; + + for (const HloInstruction* instruction : instruction_list_.instructions()) { + for (const LogicalBuffer* logical_buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { + Buffer* buffer; + if (instruction->opcode() == HloOpcode::kWhile) { + // The while instruction defines no new buffers. Instead it reuses the + // buffers of its operand. Find the Buffer of its operand at the + // proper ShapeIndex. + const PointsToSet& operand_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1); + const LogicalBuffer* source_logical_buffer = + operand_points_to.element(logical_buffer->index())[0]; + buffer = + &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer)); + + // Mark the buffer as having indirect uses and (possibly) as live out. + buffer->has_indirect_uses = true; + buffer->live_out = + buffer->live_out || ContainsKey(live_out_set, logical_buffer); + + // Add the users of the while to the buffer's users. + bool unused; + for (const HloInstruction* user : + GetUsers(logical_buffer, points_to_analysis, &unused)) { + if (std::find(buffer->users.begin(), buffer->users.end(), user) == + buffer->users.end()) { + buffer->users.push_back(user); + buffer->unfinished_user_count++; + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } + } + } else { + buffer = &CreateBufferFromLogicalBuffer( + logical_buffer, points_to_analysis, size_function, + ContainsKey(live_out_set, logical_buffer)); + buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); + for (const HloInstruction* user : buffer->users) { + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } + } + + logical_buffer_to_buffer_id[logical_buffer] = buffer->id; + } + } + XLA_VLOG_LINES(10, ToString()); + DCHECK(Check()); +} + +Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { + VLOG(3) << "BeginInstruction " << instruction->name(); + TF_RET_CHECK(in_progress_instruction_ == nullptr); + in_progress_instruction_ = instruction; + + placed_instructions_.insert(in_progress_instruction_); + + // All buffers defined by this instruction need memory. + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() + << " is now live."; + memory_usage_ += AllocatedSize(buffer_id); + } + + // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead) + // operand. Account for this potential reuse here. + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); + return Status::OK(); +} + +Status MemoryUsageTracker::EndInstruction() { + TF_RET_CHECK(in_progress_instruction_ != nullptr); + VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); + + for (BufferId buffer_id : + buffers_used_by_instruction_.at(in_progress_instruction_)) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count--; + CHECK_GE(buffer.unfinished_user_count, 0) + << buffer.ToString() << " has negative unfinished use count."; + if (buffer.unfinished_user_count == 0) { + // Buffer is now dead.
+ VLOG(3) << " " << buffer.ToString() << " is now dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); + } + } + + // If any buffer defined by this instruction has no uses, then memory can be + // reclaimed immediately. + for (BufferId buffer_id : + buffers_defined_by_instruction_.at(in_progress_instruction_)) { + const Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + VLOG(3) << " " << buffer.ToString() << " is immediately dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); + } + } + + in_progress_instruction_ = nullptr; + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); + + return Status::OK(); +} + +int64 MemoryUsageTracker::MemoryReducedIfRematerialized( + const HloInstruction* instruction) const { + CHECK_NE(in_progress_instruction_, nullptr); + if (!IsPlaced(instruction) || instruction == in_progress_instruction_) { + return 0; + } + + // TODO(b/37687140): Rematerialization can increase peak memory consumption at + // an earlier point in the program if rematerialization extends the live range + // of the operand of the instruction being rematerialized across the live + // range of the value of the instruction being rematerialized. Don't + // rematerialize in this case (ie, return 0 here). + + // Compute the amount of memory reduced (if any) by rematerializing + // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer + // be live at this program point, so initially set memory_reduced to the + // size of its defined values. + int64 memory_reduced = 0; + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + // Avoid rematerializing instructions with indirect uses as it is difficult + // to reason about liveness after rematerializing the instruction. + // TODO(b/37714814): Consider rematerializing instructions with indirect + // uses. + if (buffers_.at(buffer_id).has_indirect_uses) { + return 0; + } + + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { + memory_reduced += AllocatedSize(buffer_id); + } + } + + // Account for any logical buffers whose live range must be extended across + // this program point. + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + if (!IsCurrentlyLive(buffer_id)) { + // This logical buffer is used by 'instruction' but is not live at this + // program point. Rematerializing 'instruction' will extend the buffer's + // live range across this program point. + memory_reduced -= AllocatedSize(buffer_id); + } + } + + return memory_reduced; +} + +Status MemoryUsageTracker::AddRematerializedInstruction( + HloInstruction* original_instruction, HloInstruction* remat_instruction) { + VLOG(3) << "AddRematerializedInstruction: original_instruction = " + << original_instruction->name() + << ", remat_instruction = " << remat_instruction->name(); + + TF_RET_CHECK(in_progress_instruction_ != nullptr); + TF_RET_CHECK(IsPlaced(original_instruction)); + TF_RET_CHECK(!IsPlaced(remat_instruction)); + CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction)); + CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction)); + + // Construct the list of buffers used and defined by the rematerialization.
+ buffers_defined_by_instruction_[remat_instruction]; + buffers_used_by_instruction_[remat_instruction] = + buffers_used_by_instruction_.at(original_instruction); + + // Account for the additional buffer uses created by the new rematerialization + // instruction. Update memory usage if the rematerialization makes a dead + // buffer live again. + for (BufferId buffer_id : + buffers_used_by_instruction_.at(original_instruction)) { + Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + // Buffer used by this instruction was dead, now is alive. + memory_usage_ += AllocatedSize(buffer.id); + } + + buffer.unfinished_user_count++; + buffer.users.push_back(remat_instruction); + } + + // Create a new set of Buffers defined by the new rematerialization + // instruction. Update the internal data structures and memory use to account + // for them. + for (BufferId old_buffer_id : + buffers_defined_by_instruction_.at(original_instruction)) { + Buffer& old_buffer = buffers_.at(old_buffer_id); + + std::vector<const HloInstruction*> placed_users; + std::vector<const HloInstruction*> unplaced_users; + for (const HloInstruction* user : old_buffer.users) { + if (IsPlaced(user)) { + CHECK(IsFinished(user)); + placed_users.push_back(user); + } else { + unplaced_users.push_back(user); + } + } + old_buffer.users = std::move(placed_users); + old_buffer.unfinished_user_count = 0; + + // Buffer is now dead. + memory_usage_ -= AllocatedSize(old_buffer.id); + + Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, + std::move(unplaced_users)); + + buffers_defined_by_instruction_.at(remat_instruction) + .push_back(new_buffer.id); + for (const HloInstruction* user : new_buffer.users) { + std::vector<BufferId>& buffers_used = + buffers_used_by_instruction_.at(user); + std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, + new_buffer.id); + } + } + + VLOG(3) << " memory usage = " << memory_usage_; + XLA_VLOG_LINES(10, ToString()); + + DCHECK(Check()); + + return Status::OK(); +} + +string MemoryUsageTracker::ToString() const { + string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", + computation_->name(), "\n"); + tensorflow::strings::StrAppend( + &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); + for (const HloInstruction* instruction : instruction_list_.instructions()) { + string inprogress = + instruction == in_progress_instruction_ ? " in-progress" : ""; + string placed = IsPlaced(instruction) ? " placed" : ""; + tensorflow::strings::StrAppend(&output, " ", instruction->name(), + inprogress, placed, "\n Defines:\n"); + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + const Buffer& buffer = buffers_[buffer_id]; + string live = IsCurrentlyLive(buffer_id) ? " live" : ""; + tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, + ", ", buffer.unfinished_user_count, + " unfinished uses\n"); + } + tensorflow::strings::StrAppend(&output, " Uses:\n"); + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + tensorflow::strings::StrAppend(&output, " ", + buffers_[buffer_id].ToString(), "\n"); + } + } + return output; +} + +bool MemoryUsageTracker::Check() const { + auto elements_are_unique = [](const std::vector<BufferId>& vec) { + return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size(); + }; + + // Verify buffers_defined_by_instruction_.
+ for (auto& instruction : computation_->instructions()) { + const std::vector<BufferId>& defined_buffers = + buffers_defined_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(defined_buffers)) + << "Instruction " << instruction->name() + << " does not have unique defined buffers: " + << tensorflow::str_util::Join( + defined_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); + + for (const Buffer& buffer : buffers_) { + if (buffer.defining_instruction == instruction.get()) { + CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), + buffer.id) != defined_buffers.end()) + << "Instruction " << instruction->name() + << " defined buffers is missing: " << buffer.ToString(); + } + } + } + + // Verify buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + const std::vector<BufferId>& used_buffers = + buffers_used_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(used_buffers)) + << "Instruction " << instruction->name() + << " does not have unique used buffers: " + << tensorflow::str_util::Join( + used_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); + } + for (const Buffer& buffer : buffers_) { + int64 unfinished_uses = 0; + for (const HloInstruction* user : buffer.users) { + const std::vector<BufferId>& used_buffers = + buffers_used_by_instruction_.at(user); + CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != + used_buffers.end()) + << "Instruction " << user->name() << " used buffers is missing " + << buffer.ToString(); + if (!IsFinished(user)) { + unfinished_uses++; + } + } + CHECK_EQ(buffer.unfinished_user_count, unfinished_uses) + << "Incorrect unplaced use count for " << buffer.ToString(); + } + + // Verify live set size against memory_usage_. + int64 live_size = 0; + for (const Buffer& buffer : buffers_) { + // The while instruction reuses its input buffers as output buffers so + // don't double count its buffers if it is currently executing. + if (IsCurrentlyLive(buffer.id) && + !(buffer.defining_instruction == in_progress_instruction_ && + in_progress_instruction_->opcode() == HloOpcode::kWhile)) { + live_size += AllocatedSize(buffer.id); + } + } + CHECK(live_size == memory_usage_) + << "Live set size " << live_size << " is not the same as memory usage " + << memory_usage_ + << ". This could happen if some nodes defined in the " + "computation are not being used/executed."; + + return true; +} + +// Computes and returns the cost of rematerializing the given instruction. +// Cost per rematerialized instruction is defined as: +// +// (flop_count + transcendental_count + element_count) / memory_reduced +// +// flop_count: from HloCostAnalysis +// transcendental_count: from HloCostAnalysis +// element_count: number of elements accessed in operands and output of +// instruction +// memory_reduced: The memory usage reduced by rematerializing the +// instruction. +// +// This is a rough estimate of the extra execution time per byte saved by +// rematerializing this instruction for its remaining uses. In general, we +// want the most memory saving for the least latency penalty, which is captured +// by this heuristic.
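A worked instance of this formula, with invented numbers (the real computation is in RematerializationCost just below):

```c++
// Invented example: an elementwise f32[1024] add whose 4 KiB output would die.
const int64 flops = 1024;                  // one add per element
const int64 transcendentals = 0;           // no exp/log/etc.
const int64 elements_accessed = 3 * 1024;  // two operands read, one output
const int64 memory_reduced = 4 * 1024;     // 4096 bytes no longer live
// cost = 256 * (1024 + 0 + 3072) / 4096 = 256
const int64 cost =
    256 * (flops + transcendentals + elements_accessed) / memory_reduced;
```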
+int64 RematerializationCost(const HloInstruction* instruction, + const MemoryUsageTracker& memory_tracker, + const HloCostAnalysis& cost_analysis, + int64 memory_reduced) { + // If none of the users of 'instruction' have been placed in the sequence (as + // tracked by memory_tracker), then rematerialization of 'instruction' is a + // zero-cost move of 'instruction' in the sequence. + if (!std::any_of(instruction->users().begin(), instruction->users().end(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { + return 0; + } + + CHECK_GT(memory_reduced, 0); + const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); + const int64 elements_accessed = + ShapeUtil::IsTuple(instruction->shape()) + ? bytes_accessed + : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( + instruction->shape().element_type()); + + // Multiply by 256 to improve precision of cost. Without this factor, + // many instructions such as many elementwise instructions would have + // zero cost because the bytes reduced can be several times greater than + // the element count. + return 256 * + (cost_analysis.flop_count(*instruction) + + cost_analysis.transcendental_count(*instruction) + + elements_accessed) / + memory_reduced; +} + +// Selects and returns the best candidate instruction for rematerialization. +// The instruction with lowest rematerialization cost is selected among those +// candidates which reduce memory use at the program point of the current +// instruction as indicated by memory_tracker. nullptr is returned if no +// candidate can be found. +HloInstruction* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, + const HloCostAnalysis& cost_analysis, + const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist) { + HloInstruction* best = nullptr; + int64 best_cost = 0; + + // TODO(b/35244891): This is currently quadratic in the number of HLO + // instructions. + for (HloInstruction* candidate : instruction_list.instructions()) { + if (!memory_tracker.IsPlaced(candidate)) { + // Only iterate up to the currently placed instruction as indicated by + // memory_tracker. We are trying to reduce memory usage at the placed + // instruction so rematerializing later values is of no benefit. + break; + } + VLOG(5) << "considering rematerialization candidate " << candidate->name(); + + if (ContainsKey(blacklist, candidate)) { + // Skip instructions on the blacklist to avoid infinite loops of + // rematerializing the same instruction(s) repeatedly.
+ VLOG(5) << "candidate " << candidate->name() + << " is excluded from rematerialization"; + continue; + } + + if (!IsRematerializable(candidate)) { + VLOG(5) << "candidate " << candidate->name() + << " not viable: is not rematerializable"; + continue; + } + + const int64 memory_reduced = + memory_tracker.MemoryReducedIfRematerialized(candidate); + + if (memory_reduced <= 0) { + VLOG(5) << "candidate " << candidate->name() + << " memory reduced = " << memory_reduced << " <= 0"; + continue; + } + + const int cost = RematerializationCost(candidate, memory_tracker, + cost_analysis, memory_reduced); + + VLOG(5) << "candidate " << candidate->name() << ", memory reduced " + << memory_reduced << ", cost per byte " << cost; + + if (best == nullptr || cost < best_cost) { + VLOG(5) << "candidate " << candidate->name() << " now best"; + best = candidate; + best_cost = cost; + } + } + return best; +} + +} // namespace + +StatusOr HloRematerialization::ComputePeakMemory( + const HloComputation* computation, + const std::vector& order) const { + InstructionList instruction_list(order); + MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, + instruction_list); + int64 peak_memory = tracker.memory_usage(); + for (const HloInstruction* instruction : order) { + TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); + TF_ASSIGN_OR_RETURN(int64 callee_usage, + CalledComputationsMemoryUsage(instruction)); + peak_memory = + std::max(peak_memory, tracker.memory_usage() + callee_usage); + TF_RETURN_IF_ERROR(tracker.EndInstruction()); + } + VLOG(1) << "Peak memory for " << computation->name() << ": " + << HumanReadableNumBytes(peak_memory); + return peak_memory; +} + +StatusOr HloRematerialization::CalledComputationsMemoryUsage( + const HloInstruction* instruction) const { + const CallSite* callsite = + call_graph_->GetNode(instruction->parent()).GetCallSite(instruction); + if (callsite == nullptr || callsite->context() == CallContext::kParallel) { + return 0; + } + int64 callee_usage = 0; + for (const HloComputation* computation : callsite->called_computations()) { + TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation)); + callee_usage += computation_peak_memory_.at(computation); + } + return callee_usage; +} + +StatusOr HloRematerialization::RematerializeComputation( + HloComputation* computation, + SequentialHloOrdering::HloModuleSequence* sequence, + int64 memory_limit_bytes) { + VLOG(1) << "Rematerializing computation " << computation->name() + << " with limit " << HumanReadableNumBytes(memory_limit_bytes); + VLOG(1) << "peak memory usage is " + << HumanReadableNumBytes(computation_peak_memory_.at(computation)); + CHECK(!ContainsKey(rematerialized_computations_, computation)); + + InstructionList instruction_list(sequence->at(computation)); + MemoryUsageTracker memory_tracker(computation, size_function_, + *points_to_analysis_, instruction_list); + bool changed = false; + + // To avoid an infinite loop rematerializing the same set of instructions ad + // infinitum, keep a blacklist of instructions which should not be + // rematerialized. + tensorflow::gtl::FlatSet blacklist; + + // If the rematerialization makes the source instruction dead, then the + // rematerialization is added to 'remat_move_instructions' (the + // rematerialization is essentially a move). If the next rematerialization of + // the instruction is also a move then the rematerialization is added to the + // blacklist. 
+ tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions; + + // The peak memory of the computation at any point in the instruction + // sequence. + int64 peak_memory = memory_tracker.memory_usage(); + + // Total count of instructions rematerialized. + int64 remat_count = 0; + // Total count of clones created minus number of original rematerialized + // instructions which are dead. + int64 net_instructions_added = 0; + + const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); + + // Iterate through all instructions in the sequence. At each instruction + // (program point) if memory_usage exceeds the specified limit then + // rematerialize HLO instructions until memory_usage is reduced. + int64 instruction_index = 0; + for (auto list_it = instruction_list.instructions().begin(); + list_it != instruction_list.instructions().end(); ++list_it) { + HloInstruction* instruction = *list_it; + TF_ASSIGN_OR_RETURN(int64 callee_usage, + CalledComputationsMemoryUsage(instruction)); + TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction)); + + VLOG(2) << "Program point at " << instruction->name() + << ", memory usage = " << memory_tracker.memory_usage() + << ", callee usage = " << callee_usage << ", [" << instruction_index + << "/" << instruction_list.instructions().size() << "]"; + instruction_index++; + + while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { + VLOG(2) << "Over memory limit at instruction " << instruction->name() + << ", using " + << HumanReadableNumBytes(memory_tracker.memory_usage() + + callee_usage) + << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); + + HloInstruction* best = PickRematerializationCandidate( + memory_tracker, instruction_list, cost_analysis_, blacklist); + + if (best == nullptr) { + VLOG(3) << "Unable to find rematerialization candidate at program " + "point " + << instruction->name() << ". Memory usage = " + << HumanReadableNumBytes(memory_tracker.memory_usage() + + callee_usage); + break; + } + + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << memory_tracker.MemoryReducedIfRematerialized(best) << ")"; + changed = true; + remat_count++; + + HloInstruction* remat = + computation->AddInstruction(best->Clone(/*suffix=*/"remat")); + + // Replace each remaining use of 'best' with the rematerialization. + std::vector<HloInstruction*> best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker.IsPlaced(user)) { + VLOG(2) << " Replacing use of " << best->name() << " in " + << user->name() << " with " << remat->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + } + } + + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR( + memory_tracker.AddRematerializedInstruction(best, remat)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction *and* the earliest unplaced last use of any + // operands of remat. Unplaced uses of the remat's operands are included + // because we don't want to extend the live range of remat's operands as + // this could increase memory usage.
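To see why the operands' unplaced users matter, an invented example (all names hypothetical):

```c++
// Original order:  o = ...; v = f(o); ...; u1 = g(v); u2 = h(o);
// Suppose v is rematerialized for its remaining use u1. Placing v.remat
// immediately before u1 would keep o alive all the way to u1, even though
// o's own last use h(o) may come earlier. By also collecting o's unplaced
// users (here u2), v.remat is inserted before the earliest of {u1, u2}, so
// o's live range is not extended.
```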
+ std::vector<HloInstruction*> place_before = remat->users(); + for (auto* operand : remat->operands()) { + for (auto* operand_user : operand->users()) { + if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { + place_before.push_back(operand_user); + } + } + } + instruction_list.InsertBeforeInstructions(remat, place_before); + + // If the rematerialized instruction is dead then rematerialization is + // essentially a move. Don't delete the instruction now because we don't + // want duplicate HloInstruction* values during the course of the + // transformation because we keep maps with HloInstruction* values as + // keys. + if (best->users().empty()) { + VLOG(2) << best->name() << " is now dead"; + if (ContainsKey(remat_move_instructions, best)) { + // Previously, 'best' was a rematerialization which killed the + // instruction it was a copy of. Now 'remat' is a rematerialization + // of 'best' and kills 'best'. Stop rematerializing this instruction + // to avoid an infinite loop. + blacklist.insert(remat); + } + remat_move_instructions.insert(remat); + } else { + net_instructions_added++; + } + + VLOG(3) << "memory_usage after rematerialization = " + << memory_tracker.memory_usage(); + } + + const CallSite* callsite = call_graph_node.GetCallSite(instruction); + if (callsite != nullptr && + callsite->context() == CallContext::kSequential && + memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { + // Memory usage exceeds the limit. Try to rematerialize any + // subcomputation(s) that this instruction calls. + VLOG(1) << "Memory usage still over the limit (" + << (memory_tracker.memory_usage() + callee_usage) << " > " + << memory_limit_bytes + << "). Rematerializing computations called by " + << instruction->name(); + + // Recompute callee usage to account for any rematerialization performed + // in the callee computations. + for (HloComputation* called_computation : + callsite->called_computations()) { + if (!ContainsKey(rematerialized_computations_, called_computation)) { + // Memory limit for the subcomputation is the memory limit less the + // amount of memory used at this point in the computation. + int64 subcomputation_memory_limit_bytes = std::max<int64>( + 0, memory_limit_bytes - memory_tracker.memory_usage()); + TF_ASSIGN_OR_RETURN( + bool subcomputation_changed, + RematerializeComputation(called_computation, sequence, + subcomputation_memory_limit_bytes)); + changed |= subcomputation_changed; + } + } + TF_ASSIGN_OR_RETURN(callee_usage, + CalledComputationsMemoryUsage(instruction)); + } + + peak_memory = std::max(peak_memory, + memory_tracker.memory_usage() + callee_usage); + VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory); + + TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); + } + + // Verify some invariants on the memory tracker. + CHECK_EQ(memory_tracker.memory_usage(), 0); + for (auto& instruction : computation->instructions()) { + CHECK(memory_tracker.IsPlaced(instruction.get())); + } + + VLOG(1) << "In computation " << computation->name() << " rematerialized " + << remat_count << " instructions; " << net_instructions_added + << " net instructions added"; + VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory) + << " (was " + << HumanReadableNumBytes(computation_peak_memory_.at(computation)) + << ")"; + + // Update peak memory used by computation. + computation_peak_memory_.at(computation) = peak_memory; + + // Update order to include rematerialized instructions.
+  sequence->at(computation)
+      .assign(instruction_list.instructions().begin(),
+              instruction_list.instructions().end());
+
+  rematerialized_computations_.insert(computation);
+
+  instructions_rematerialized_ += remat_count;
+  net_instructions_added_ += net_instructions_added;
+
+  return changed;
+}
+
+StatusOr<bool> HloRematerialization::Run(
+    HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
+    int64 memory_limit_bytes) {
+  // The sequence is constructed entirely by this method.
+  TF_RET_CHECK(sequence->empty());
+
+  VLOG(1) << "HloRematerialization() with memory limit of "
+          << HumanReadableNumBytes(memory_limit_bytes);
+
+  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
+
+  // Adjust memory limit to account for the output of the entry
+  // computation. This is necessary because the per-computation accounting in
+  // MemoryUsageTracker does not include the output, as it is typically
+  // allocated by the caller.
+  int64 module_output_size = 0;
+  ShapeUtil::ForEachSubshape(
+      module->entry_computation()->root_instruction()->shape(),
+      [&module_output_size, this](const Shape& subshape,
+                                  const ShapeIndex& /*index*/) {
+        module_output_size += size_function_(subshape);
+      });
+
+  const int64 adjusted_memory_limit_bytes =
+      memory_limit_bytes - module_output_size;
+  VLOG(1) << "Adjusted memory limit accounting for output ("
+          << HumanReadableNumBytes(module_output_size)
+          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
+
+  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+  // Create initial sequence of HLO instructions.
+  TF_ASSIGN_OR_RETURN(*sequence,
+                      CreateMemoryMinimizingSequence(
+                          *module, [this](const LogicalBuffer& buffer) {
+                            return size_function_(buffer.shape());
+                          }));
+  // Compute peak memory usage of all computations in the module called in a
+  // sequential context.
+  call_graph_ = CallGraph::Build(module);
+  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
+      [this, sequence](const CallGraphNode& node) -> Status {
+        if (node.context() == CallContext::kSequential) {
+          TF_ASSIGN_OR_RETURN(
+              computation_peak_memory_[node.computation()],
+              ComputePeakMemory(node.computation(),
+                                sequence->at(node.computation())));
+        }
+        return Status::OK();
+      }));
+
+  // The peak memory usage of the module equals the peak memory use of the
+  // entry computation plus the output size of the computation. This is because
+  // the peak memory for a computation does not include the output as this is
+  // typically accounted for in the caller.
+  const int64 before_peak_memory =
+      computation_peak_memory_.at(module->entry_computation()) +
+      module_output_size;
+  VLOG(1) << "Peak memory usage of module (before): "
+          << HumanReadableNumBytes(before_peak_memory);
+
+  // Run cost analysis. Operation cost is used in the heuristic for selecting
+  // instructions for rematerialization.
+  TF_RETURN_IF_ERROR(
+      module->entry_computation()->root_instruction()->Accept(&cost_analysis_));
+
+  // Subcomputations called by the entry computation will also be
+  // rematerialized.
+  TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
+                                        module->entry_computation(), sequence,
+                                        adjusted_memory_limit_bytes));
+
+  // Rematerialization can introduce dead code. This occurs if all uses of an
+  // instruction are replaced with rematerializations of the instruction.
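+  // For example (hypothetical names): if every remaining use of %bcast was
+  // replaced with a rematerialized clone %bcast.remat, then %bcast itself is
+  // dead and is deleted by the DCE pass below.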
+  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
+  changed |= dead_code_removed;
+
+  // After DCE, the module sequence may include instructions which no longer
+  // exist.
+  for (const auto& computation : module->computations()) {
+    if (sequence->at(computation.get()).size() !=
+        computation->instruction_count()) {
+      // A size mismatch between the computation instruction count and the size
+      // of the ordering of instructions can only be caused by DCE. Rebuild the
+      // order by removing the deleted instructions from the order.
+      tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
+      for (const auto& instruction : computation->instructions()) {
+        instruction_set.insert(instruction.get());
+      }
+      // Move the old order into a temporary vector, then build the new order
+      // in place.
+      std::vector<const HloInstruction*>& order =
+          sequence->at(computation.get());
+      std::vector<const HloInstruction*> old_order;
+      using std::swap;
+      swap(order, old_order);
+      std::copy_if(old_order.begin(), old_order.end(),
+                   std::back_inserter(order),
+                   [&instruction_set](const HloInstruction* instruction) {
+                     return ContainsKey(instruction_set, instruction);
+                   });
+      TF_RET_CHECK(sequence->at(computation.get()).size() ==
+                   computation->instruction_count());
+    }
+  }
+  VLOG(1) << "Rematerialized " << instructions_rematerialized_
+          << " instructions in module " << module->name() << "; "
+          << net_instructions_added_ << " net instructions added";
+  const int64 current_peak_memory =
+      computation_peak_memory_.at(module->entry_computation()) +
+      module_output_size;
+  VLOG(1) << "Peak memory usage of module now "
+          << HumanReadableNumBytes(current_peak_memory) << " ("
+          << current_peak_memory << " bytes), was "
+          << HumanReadableNumBytes(before_peak_memory) << " ("
+          << before_peak_memory << " bytes)";
+  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
+  VLOG(1) << "Reduced peak memory by "
+          << HumanReadableNumBytes(reduced_peak_memory) << " ("
+          << reduced_peak_memory << " bytes)";
+
+  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
+
+  if (current_peak_memory > memory_limit_bytes) {
+    LOG(WARNING) << "Can't reduce memory use below "
+                 << HumanReadableNumBytes(memory_limit_bytes)
+                 << " by rematerialization (only reduced to "
+                 << HumanReadableNumBytes(current_peak_memory) << ")";
+  }
+
+  return changed;
+}
+
+/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
+    const HloRematerialization::ShapeSizeFunction& size_function,
+    int64 memory_limit_bytes, HloModule* hlo_module,
+    SequentialHloOrdering::HloModuleSequence* sequence) {
+  HloRematerialization remat(size_function);
+  return remat.Run(hlo_module, sequence, memory_limit_bytes);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
new file mode 100644
index 00000000000..1693f93183b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+   ==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+
+namespace xla {
+
+class HloRematerialization {
+ public:
+  using ShapeSizeFunction = std::function<int64(const Shape&)>;
+
+  // Rematerialize HLO instructions in the given module to reduce peak memory
+  // use below memory_limit_bytes, where memory use is defined as the total
+  // size of all live HLO instruction values. Parameters and constants are
+  // included in memory use estimates. Method parameters:
+  //
+  // size_function: Function which returns the size in bytes of the top-level
+  //   buffer of the given shape.
+  //
+  // memory_limit_bytes: The threshold number of bytes to reduce memory use to
+  //   via rematerialization.
+  //
+  // hlo_module: HLO module to rematerialize instructions in.
+  //
+  // sequence: Should point to an empty HloModuleSequence. Upon return
+  //   contains the HLO instruction order which was used for
+  //   rematerialization. This is the order in which HLO instructions should
+  //   be emitted to minimize memory use.
+  //
+  // Returns whether any instructions were rematerialized. If memory use is
+  // already below the given limit then no instructions are rematerialized and
+  // false is returned.
+  //
+  // CSE will undo the effects of this optimization and should not be run after
+  // this pass. In general, this pass should be run very late, immediately
+  // before code generation.
+  static StatusOr<bool> RematerializeAndSchedule(
+      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
+      HloModule* hlo_module,
+      SequentialHloOrdering::HloModuleSequence* sequence);
+
+ protected:
+  HloRematerialization(const ShapeSizeFunction& size_function)
+      : size_function_(size_function), cost_analysis_(size_function_) {}
+  ~HloRematerialization() {}
+
+  // Runs rematerialization on the given module. Returns whether the module was
+  // changed. memory_limit is the target maximum peak memory usage by the
+  // module. sequence should be an empty HloModuleSequence. Upon return
+  // sequence contains the memory-minimizing order in which to emit the HLO
+  // instructions.
+  StatusOr<bool> Run(HloModule* module,
+                     SequentialHloOrdering::HloModuleSequence* sequence,
+                     int64 memory_limit);
+
+  // Rematerializes instructions within the given computation. 'order' is the
+  // order in which the computation's instructions will be emitted in the
+  // backend. Rematerialized instructions will be added to the HLO computation
+  // and inserted into 'order'.
+  StatusOr<bool> RematerializeComputation(
+      HloComputation* computation,
+      SequentialHloOrdering::HloModuleSequence* sequence,
+      int64 computation_memory_limit);
+
+  // Computes and returns the peak memory used by the given computation. The
+  // peak memory is the maximum total size of all live HLO instruction values
+  // at any program point.
+  // 'order' is the order in which the HLO instructions will be emitted, and
+  // is used to determine the lifespans of HLO values.
+  StatusOr<int64> ComputePeakMemory(
+      const HloComputation* computation,
+      const std::vector<const HloInstruction*>& order) const;
+
+  // Returns the peak memory usage of the called computations for the given
+  // instruction. Zero is returned if the instruction calls no computations.
+  StatusOr<int64> CalledComputationsMemoryUsage(
+      const HloInstruction* instruction) const;
+
+  // Function which computes the size of the top-level buffer of a shape.
+  const ShapeSizeFunction size_function_;
+
+  // Call graph of the hlo_module.
+  std::unique_ptr<CallGraph> call_graph_;
+
+  // Analysis used for computing the rematerialization cost of instructions.
+  HloCostAnalysis cost_analysis_;
+
+  // The peak memory usage of each computation. The map contains only those
+  // computations called from a sequential context
+  // (CallContext::kSequential). These values are updated as rematerialization
+  // occurs.
+  tensorflow::gtl::FlatMap<const HloComputation*, int64>
+      computation_peak_memory_;
+
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+
+  // Set of computations which have had rematerialization
+  // applied. Rematerialization is only applied once per computation.
+  tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;
+
+  // Count of the total instructions rematerialized.
+  int64 instructions_rematerialized_ = 0;
+
+  // Count of the net instructions added to the HLO module by
+  // rematerialization. This can be different from instructions_rematerialized_
+  // because some rematerializations are effectively moves in the HLO
+  // schedule. In these cases, the rematerialization instruction replaces all
+  // uses of the original instruction and the original instruction is
+  // dead. Hence, no net instructions were added.
+  int64 net_instructions_added_ = 0;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
new file mode 100644
index 00000000000..f306bcc309c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -0,0 +1,531 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+using ::testing::_;
+
+class HloRematerializationTest : public HloTestBase {
+ protected:
+  // Creates and returns a computation which can benefit from
+  // rematerialization. The computation looks like:
+  //
+  //   F32[] %param = {...}
+  //   F32[1024] %bcast = broadcast(%param)
+  //   F32[1024] %negate = negate(%bcast)
+  //   F32[2048] %concat_1 = concat({%negate, %negate})
+  //   F32[1] %slice_1 = slice(%concat_1, {0:1})
+  //   F32[1025] %concat_2 = concat({%bcast, %slice_1})
+  //   F32[1] %slice_2 = slice(%concat_2, {0:1});
+  //
+  // The instruction %bcast can be rematerialized before its use at %concat_2
+  // to reduce peak memory usage. This avoids %bcast and %concat_1 being
+  // simultaneously live. Peak memory use is about 16KB before
+  // rematerialization (during execution of %concat_1) and about 12KB after
+  // rematerializing %bcast for its use in %concat_2.
+  std::unique_ptr<HloComputation> MakeRematerializableComputation(
+      const string& suffix = "") {
+    auto builder = HloComputation::Builder(TestName() + suffix);
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+    auto bcast = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+    auto negate = builder.AddInstruction(
+        HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
+    auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
+        ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate},
+        /*dimension=*/0));
+    auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
+        vec1_shape_, concat_1, /*start_indices=*/{0},
+        /*limit_indices=*/{1},
+        /*strides=*/{1}));
+    auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
+        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
+        /*dimension=*/0));
+    // Add a final slice to make the parameter shape match the output shape
+    // which is necessary to use this computation in a while.
+    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
+                                                       /*start_indices=*/{0},
+                                                       /*limit_indices=*/{1},
+                                                       /*strides=*/{1}));
+    return builder.Build();
+  }
+
+  // Creates and returns a computation which includes a while and can benefit
+  // from rematerialization. The computation looks like:
+  //
+  //   F32[] %param = {...}
+  //   F32[1024] %bcast = broadcast(%param)
+  //   F32[1] %slice_1 = slice(%bcast, {0:1})
+  //   F32[1] %while = while(%slice_1, while_body, while_cond)
+  //   F32[1025] %concat = concat({%bcast, %while})
+  //   F32[1] %slice_2 = slice(%concat, {0:1});
+  //
+  // The instruction %bcast can be rematerialized before its use at %concat to
+  // reduce peak memory usage. This avoids %bcast being live during execution
+  // of the while.
+  // Peak memory use is the maximum of 8KB and (4KB plus the memory use of the
+  // while subcomputations).
+  std::unique_ptr<HloComputation> MakeRematerializableWhileComputation(
+      HloComputation* while_cond, HloComputation* while_body,
+      const string& suffix = "") {
+    auto builder = HloComputation::Builder(TestName() + suffix);
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+    auto bcast = builder.AddInstruction(
+        HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+    auto slice_1 = builder.AddInstruction(
+        HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
+                                    /*limit_indices=*/{1},
+                                    /*strides=*/{1}));
+    auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
+        vec1_shape_, while_cond, while_body, slice_1));
+    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
+        ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
+        /*dimension=*/0));
+    builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
+                                                       /*start_indices=*/{0},
+                                                       /*limit_indices=*/{1},
+                                                       /*strides=*/{1}));
+    return builder.Build();
+  }
+
+  // Create and return a trivial computation appropriate for use as a while
+  // condition.
+  std::unique_ptr<HloComputation> MakeConditionComputation() {
+    auto builder = HloComputation::Builder(TestName() + ".cond");
+    builder.AddInstruction(
+        HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+    builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+    return builder.Build();
+  }
+
+  // Return the byte size of the top-level buffer of the given shape.
+  static int64 ByteSizeOf(const Shape& shape) {
+    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+  }
+
+  // Various shapes used in the canned computations.
+  const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
+  const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
+  const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024});
+};
+
+// Test rematerialization of a single computation produced by
+// MakeRematerializableComputation.
+TEST_F(HloRematerializationTest, SingleComputation) {
+  auto module = CreateNewModule();
+  HloComputation* computation =
+      module->AddEntryComputation(MakeRematerializableComputation());
+
+  // Find and save the original broadcast instruction which should be
+  // rematerialized.
+  const HloInstruction* slice = computation->root_instruction();
+  ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
+  const HloInstruction* concat = slice->operand(0);
+  const HloInstruction* bcast = concat->operand(0);
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  // The computation requires 16KB without rematerialization but uses only 12KB
+  // with rematerialization, so pick a memory limit between these values
+  // (14KB).
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/14 * 1024, module.get(), &sequence));
+  EXPECT_TRUE(changed);
+
+  // Root should not have changed.
+  EXPECT_EQ(computation->root_instruction(), slice);
+
+  // The broadcast should have been rematerialized.
+  const HloInstruction* remat_bcast = concat->operand(0);
+  EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
+
+  // The rematerialized broadcast should be immediately before the concat in
+  // the sequence.
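+  // For illustration (instruction names are approximate): the tail of the
+  // sequence is expected to look like
+  //   ..., %slice_1, %bcast.remat, %concat_2, %slice_2
+  // which the index checks below verify.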
+  EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
+            concat);
+  EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
+            remat_bcast);
+}
+
+// Test rematerialization of a single computation produced by
+// MakeRematerializableComputation but with a sufficiently high memory limit
+// such that no instructions are rematerialized.
+TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
+  auto module = CreateNewModule();
+  HloComputation* computation =
+      module->AddEntryComputation(MakeRematerializableComputation());
+
+  EXPECT_EQ(computation->instruction_count(), 7);
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/20 * 1024, module.get(), &sequence));
+
+  // No instructions should have been rematerialized.
+  EXPECT_FALSE(changed);
+  EXPECT_EQ(computation->instruction_count(), 7);
+}
+
+// Test rematerialization of a computation which calls another computation via
+// a while. Both the entry computation and while body computation can have
+// memory usage reduced via rematerialization; however, the memory limit is set
+// such that only one computation needs to have an instruction rematerialized.
+// The entry computation should be the one chosen because rematerialization in
+// the while will presumably be more expensive.
+TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
+  auto module = CreateNewModule();
+
+  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  HloComputation* while_cond =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  HloComputation* body_computation = module->AddEmbeddedComputation(
+      MakeRematerializableComputation(/*suffix=*/".body"));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(MakeRematerializableWhileComputation(
+          while_cond, /*while_body=*/body_computation));
+
+  EXPECT_EQ(entry_computation->instruction_count(), 6);
+  EXPECT_EQ(body_computation->instruction_count(), 7);
+
+  // The body computation uses 16KB and the entry computation uses 2KB at the
+  // while, so the peak memory use of the module is 18KB. Set the memory limit
+  // a bit lower (17KB) to force rematerialization of the entry computation.
+  SequentialHloOrdering::HloModuleSequence sequence;
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/17 * 1024, module.get(), &sequence));
+  EXPECT_TRUE(changed);
+
+  // Only the entry computation should have a rematerialized instruction added.
+  EXPECT_EQ(entry_computation->instruction_count(), 7);
+  EXPECT_EQ(body_computation->instruction_count(), 7);
+}
+
+// Test rematerialization of a computation which calls another computation via
+// a while. Both the entry computation and while body computation should have
+// instructions rematerialized.
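+// With the 15KB limit used below (less than both the body's ~16KB peak and
+// the module's ~18KB peak noted in the previous test), neither computation
+// fits as-is, so each must rematerialize its broadcast.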
+TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
+  auto module = CreateNewModule();
+
+  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  HloComputation* while_cond =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  HloComputation* body_computation = module->AddEmbeddedComputation(
+      MakeRematerializableComputation(/*suffix=*/".body"));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(MakeRematerializableWhileComputation(
+          while_cond, /*while_body=*/body_computation));
+
+  EXPECT_EQ(entry_computation->instruction_count(), 6);
+  EXPECT_EQ(body_computation->instruction_count(), 7);
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/15 * 1024, module.get(), &sequence));
+  EXPECT_TRUE(changed);
+
+  // Both computations should have a rematerialized instruction added.
+  EXPECT_EQ(entry_computation->instruction_count(), 7);
+  EXPECT_EQ(body_computation->instruction_count(), 8);
+}
+
+// Test rematerialization of a doubly nested computation. All computations
+// should have an instruction rematerialized.
+TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
+  auto module = CreateNewModule();
+
+  auto cond_builder = HloComputation::Builder(TestName() + ".cond");
+  cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+  HloComputation* while_cond =
+      module->AddEmbeddedComputation(cond_builder.Build());
+
+  HloComputation* inner_computation = module->AddEmbeddedComputation(
+      MakeRematerializableComputation(/*suffix=*/".inner"));
+  HloComputation* middle_computation =
+      module->AddEmbeddedComputation(MakeRematerializableWhileComputation(
+          while_cond, /*while_body=*/inner_computation,
+          /*suffix=*/".middle"));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(MakeRematerializableWhileComputation(
+          while_cond, /*while_body=*/middle_computation));
+
+  EXPECT_EQ(entry_computation->instruction_count(), 6);
+  EXPECT_EQ(middle_computation->instruction_count(), 6);
+  EXPECT_EQ(inner_computation->instruction_count(), 7);
+
+  // If all computations are maximally rematerialized then peak memory usage is
+  // ~12KB, so pick something slightly larger.
+  SequentialHloOrdering::HloModuleSequence sequence;
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/13 * 1024, module.get(), &sequence));
+  EXPECT_TRUE(changed);
+
+  // All computations should have a rematerialized instruction added.
+  EXPECT_EQ(entry_computation->instruction_count(), 7);
+  EXPECT_EQ(middle_computation->instruction_count(), 7);
+  EXPECT_EQ(inner_computation->instruction_count(), 8);
+}
+
+TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
+  // Test that a single instruction is rematerialized several times.
+  // Module:
+  //
+  // Entry computation:
+  //   F32[] %param = {...}
+  //   F32[1024] %bcast = broadcast(%param)
+  //   F32[1024] %add_1 = add(%bcast, %bcast)
+  //   F32[1024] %call_1 = call(Subcomputation, {%add_1})
+  //   F32[1024] %add_2 = add(%bcast, %call_1)
+  //   F32[1024] %call_2 = call(Subcomputation, {%add_2})
+  //   F32[1024] %add_3 = add(%bcast, %call_2)
+  //   F32[1024] %call_3 = call(Subcomputation, {%add_3})
+  //   F32[1024] %add_4 = add(%bcast, %call_3)
+  //
+  // Subcomputation:
+  //   F32[1024] %param = {...}
+  //   F32[2048] %concat = concat({%param, %param})
+  //   F32[1024] %slice = slice(%concat)
+  //
+  // The value %bcast is live across each call of Subcomputation (which
+  // requires 8KB) though the value is not used in the calls. Rematerializing
+  // %bcast across these calls reduces peak memory use from ~20KB down to
+  // ~16KB.
+  auto module = CreateNewModule();
+
+  HloComputation* subcomputation = nullptr;
+  {
+    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
+    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
+        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
+        /*dimension=*/0));
+    builder.AddInstruction(HloInstruction::CreateSlice(
+        vec1024_shape_, concat, /*start_indices=*/{0},
+        /*limit_indices=*/{1024}, /*strides=*/{1}));
+    subcomputation = module->AddEmbeddedComputation(builder.Build());
+  }
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+  auto bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
+  auto call_1 = builder.AddInstruction(
+      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
+  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
+  auto call_2 = builder.AddInstruction(
+      HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation));
+  auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, call_2));
+  auto call_3 = builder.AddInstruction(
+      HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation));
+  auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, call_3));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  auto count_broadcasts = [](const HloComputation* computation) {
+    int64 bcast_count = 0;
+    for (auto& instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kBroadcast) {
+        bcast_count++;
+      }
+    }
+    return bcast_count;
+  };
+
+  // Before rematerialization there should be a single broadcast instruction in
+  // the graph.
+  EXPECT_EQ(count_broadcasts(entry_computation), 1);
+  EXPECT_EQ(entry_computation->instruction_count(), 9);
+
+  EXPECT_EQ(add_2->operand(0), bcast);
+  EXPECT_EQ(add_3->operand(0), bcast);
+  EXPECT_EQ(add_4->operand(0), bcast);
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  // Pick a memory limit somewhere between 20KB (peak memory possible with
+  // rematerialization) and 24KB (initial peak memory including parameter and
+  // output).
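+  // (Each F32[1024] value occupies 4KB, so each rematerialization of %bcast
+  // trades one extra broadcast instruction for roughly 4KB less live memory
+  // across the corresponding call.)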
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence));
+  EXPECT_TRUE(changed);
+
+  // The broadcast should have been rematerialized 3 times.
+  EXPECT_EQ(count_broadcasts(entry_computation), 4);
+  EXPECT_EQ(entry_computation->instruction_count(), 12);
+
+  // The operands of add_2, add_3, and add_4 should all be rematerialized
+  // broadcasts.
+  EXPECT_NE(add_2->operand(0), bcast);
+  EXPECT_THAT(add_2->operand(0), op::Broadcast(param));
+  EXPECT_NE(add_3->operand(0), bcast);
+  EXPECT_THAT(add_3->operand(0), op::Broadcast(param));
+  EXPECT_NE(add_4->operand(0), bcast);
+  EXPECT_THAT(add_4->operand(0), op::Broadcast(param));
+}
+
+class IndirectUseTest : public HloRematerializationTest,
+                        public ::testing::WithParamInterface<bool> {};
+
+TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
+  // Test that a rematerializable instruction is not rematerialized if it has
+  // an indirect use. The test is parameterized on whether the value has an
+  // indirect use, and the instruction should be rematerialized iff the value
+  // has no indirect use. Module:
+  //
+  // Entry computation:
+  //   F32[] %param = {...}
+  //   F32[1024] %bcast = broadcast(%param)
+  //   F32[1024] %add_1 = add(%bcast, %bcast)
+  //   F32[1024] %call = call(Subcomputation, {%add_1})
+  //   F32[1024] %add_2 = add(%bcast, %call)
+  //   {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2)
+  //   F32[1024] %gte = GetTupleElement(%tuple, 0)
+  //   F32[1024] %negate = negate(%gte)
+  //
+  // Subcomputation:
+  //   F32[1024] %param = {...}
+  //   F32[2048] %concat = concat({%param, %param})
+  //   F32[1024] %slice = slice(%concat)
+  //
+  // The value %bcast is live across the call and rematerialization of %bcast
+  // across that point would reduce peak memory use by 4KB. However, %bcast is
+  // used indirectly by %negate, so rematerialization should not happen.
+  //
+  // This test is parameterized on whether the broadcast has an indirect use or
+  // not. The indirect use is controlled by the index of the GetTupleElement
+  // instruction. If the element is 0, then the %negate operand aliases %bcast
+  // (i.e., %bcast is used indirectly by %negate); otherwise the %negate
+  // operand aliases %add_2.
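+  // Put differently: with tuple index 0, %gte aliases %bcast, so even if every
+  // direct use of %bcast were replaced with a rematerialized copy, %negate
+  // would still observe the original buffer through the tuple. The pass
+  // therefore must leave %bcast alone.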
+  const bool indirectly_used = GetParam();
+  auto module = CreateNewModule();
+
+  HloComputation* subcomputation = nullptr;
+  {
+    auto builder = HloComputation::Builder(TestName() + ".subcomputation");
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, vec1024_shape_, "param"));
+    auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
+        ShapeUtil::MakeShape(xla::F32, {2048}), {param, param},
+        /*dimension=*/0));
+    builder.AddInstruction(HloInstruction::CreateSlice(
+        vec1024_shape_, concat, /*start_indices=*/{0},
+        /*limit_indices=*/{1024}, /*strides=*/{1}));
+    subcomputation = module->AddEmbeddedComputation(builder.Build());
+  }
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+  auto bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+  auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, bcast));
+  auto call_1 = builder.AddInstruction(
+      HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation));
+  auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary(
+      vec1024_shape_, HloOpcode::kAdd, bcast, call_1));
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2}));
+  auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+      vec1024_shape_, tuple, indirectly_used ? 0 : 1));
+  builder.AddInstruction(
+      HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte));
+  HloComputation* entry_computation =
+      module->AddEntryComputation(builder.Build());
+
+  EXPECT_EQ(entry_computation->instruction_count(), 8);
+
+  SequentialHloOrdering::HloModuleSequence sequence;
+  // Pick a memory limit somewhere between 20KB (peak memory possible with
+  // rematerialization) and 24KB (initial peak memory including parameter and
+  // output).
+  TF_ASSIGN_OR_ASSERT_OK(
+      bool changed,
+      HloRematerialization::RematerializeAndSchedule(
+          ByteSizeOf,
+          /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence));
+  // Rematerialization should only occur if the rematerializable instruction
+  // has no indirect uses.
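+  // The parameter selects the GetTupleElement index above: true picks index 0
+  // (%bcast is aliased through the tuple, so no rematerialization), false
+  // picks index 1 (no aliasing, so %bcast is rematerialized).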
+  if (indirectly_used) {
+    EXPECT_FALSE(changed);
+    EXPECT_EQ(entry_computation->instruction_count(), 8);
+  } else {
+    EXPECT_TRUE(changed);
+    EXPECT_EQ(entry_computation->instruction_count(), 9);
+  }
+}
+
+INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest,
+                        ::testing::Values(true, false));
+
+}  // namespace
+
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
index 14800b53420..867ebc7f61a 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -66,13 +66,13 @@ class HloSubcomputationUnificationTest : public HloTestBase {
 };
 
 TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
-  auto hlo_module = MakeUnique<HloModule>("test_module");
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
 
   auto callee1 =
-      hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
   auto callee2 =
-      hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+      module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
 
   auto constant = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
@@ -83,32 +83,31 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
 
-  hlo_module->AddEntryComputation(builder.Build());
+  module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(3, hlo_module->computations().size());
+  EXPECT_EQ(3, module->computations().size());
   EXPECT_NE(x->to_apply(), y->to_apply());
 
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "before unification", false, false, nullptr);
   }
-  EXPECT_TRUE(
-      HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+  EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "after unification", false, false, nullptr);
   }
-  EXPECT_EQ(2, hlo_module->computations().size());
+  EXPECT_EQ(2, module->computations().size());
   EXPECT_EQ(x->to_apply(), y->to_apply());
 }
 
 TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
-  auto hlo_module = MakeUnique<HloModule>("test_module");
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
 
   auto callee1 =
-      hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
   auto callee2 =
-      hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+      module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
 
   auto constant1 = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
@@ -121,33 +120,32 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
   builder.AddInstruction(
       HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
 
-  hlo_module->AddEntryComputation(builder.Build());
+  module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(3, hlo_module->computations().size());
+  EXPECT_EQ(3, module->computations().size());
   EXPECT_NE(x->to_apply(), y->to_apply());
 
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "before unification", false, false, nullptr);
   }
-  EXPECT_TRUE(
-      HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+  EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "after unification", false, false, nullptr);
   }
-  EXPECT_EQ(2, hlo_module->computations().size());
+  EXPECT_EQ(2, module->computations().size());
   EXPECT_EQ(x->to_apply(), y->to_apply());
 }
 
 // Do not unify subcomputations with different parameter shapes.
 TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
-  auto hlo_module = MakeUnique<HloModule>("test_module");
+  auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
 
-  auto callee1 = hlo_module->AddEmbeddedComputation(
-      CreateR1S32AdditionComputation(r1s32_5_));
-  auto callee2 = hlo_module->AddEmbeddedComputation(
-      CreateR1S32AdditionComputation(r1s32_3_));
+  auto callee1 =
+      module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_5_));
+  auto callee2 =
+      module->AddEmbeddedComputation(CreateR1S32AdditionComputation(r1s32_3_));
 
   auto param1 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
@@ -160,28 +158,27 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
   builder.AddInstruction(HloInstruction::CreateConcatenate(
       ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
 
-  hlo_module->AddEntryComputation(builder.Build());
+  module->AddEntryComputation(builder.Build());
 
-  EXPECT_EQ(3, hlo_module->computations().size());
+  EXPECT_EQ(3, module->computations().size());
   EXPECT_NE(x->to_apply(), y->to_apply());
 
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "before unification", false, false, nullptr);
   }
-  EXPECT_FALSE(
-      HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+  EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
   if (VLOG_IS_ON(1)) {
-    hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+    hlo_graph_dumper::DumpGraph(*module->entry_computation(),
                                 "after unification", false, false, nullptr);
   }
-  EXPECT_EQ(3, hlo_module->computations().size());
+  EXPECT_EQ(3, module->computations().size());
   EXPECT_NE(x->to_apply(), y->to_apply());
 }
 
 // Regression test for b/31466798. Checks that entry_computation is still valid
 // after unification.
 TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
-  HloModule module(TestName());
+  auto module = CreateNewModule();
   for (int i = 0; i < 2; ++i) {
     HloComputation::Builder builder("pow");
     auto x =
@@ -191,15 +188,19 @@ TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
     builder.AddInstruction(
         HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
     if (i == 0) {
-      module.AddEmbeddedComputation(builder.Build());
+      module->AddEmbeddedComputation(builder.Build());
     } else {
-      module.AddEntryComputation(builder.Build());
+      module->AddEntryComputation(builder.Build());
     }
   }
 
-  EXPECT_TRUE(HloSubcomputationUnification().Run(&module).ValueOrDie());
-  EXPECT_EQ(1, module.computations().size());
-  EXPECT_EQ(module.computations().front().get(), module.entry_computation());
+  EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
+  EXPECT_EQ(1, module->computations().size());
+  EXPECT_EQ(module->computations().front().get(), module->entry_computation());
 }
 
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
new file mode 100644
index 00000000000..6707b02c5c5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -0,0 +1,214 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+using ::tensorflow::GraphDef;
+using ::tensorflow::NodeDef;
+using ::tensorflow::TensorShapeProto;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+using ::tensorflow::str_util::Join;
+
+namespace xla {
+namespace hlo_graph_dumper {
+namespace {
+
+string GetOpDefName(const HloInstruction* instruction) {
+  string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
+  tensorflow::str_util::TitlecaseString(&name, "-");
+  name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
+
+  if (instruction->opcode() == HloOpcode::kFusion) {
+    string fusion_name = ToString(instruction->fusion_kind());
+    StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+  }
+  return name;
+}
+
+TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
+  TensorShapeProto tensor_shape;
+  const Shape& shape = instruction->shape();
+  for (auto dim : shape.dimensions()) {
+    tensor_shape.add_dim()->set_size(dim);
+  }
+  return tensor_shape;
+}
+
+}  // namespace
+
+void CleanNodeName(string* name) {
+  name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
+  const string chars_to_replace = "<>[]";
+  auto pred = [&](char c) {
+    return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
+           chars_to_replace.end();
+  };
+  std::replace_if(name->begin(), name->end(), pred, '_');
+}
+
+Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
+  VLOG(2) << "Adding computation " << computation.name();
+  for (auto embedded : computation.MakeEmbeddedComputationsList()) {
+    for (auto& instruction : embedded->instructions()) {
+      TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+    }
+  }
+  for (auto& instruction : computation.instructions()) {
+    TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
+  }
+  return Status::OK();
+}
+
+const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; }
+
+const string& HloTfGraphBuilder::GetNodeNameForInstruction(
+    const HloInstruction* instruction) {
+  if (ContainsKey(instruction_to_node_name_, instruction)) {
+    return instruction_to_node_name_[instruction];
+  }
+  string node_name;
+  // If an instruction is fused, put it in the subgraph of the fusion;
+  // otherwise, put it in the computation subgraph.
+  if (instruction->IsFused()) {
+    node_name = GetNodeNameForInstruction(instruction->fusion_instruction());
+  } else {
+    node_name = instruction->parent()->name();
+    if (!instruction->metadata().op_name().empty()) {
+      // Always make computations contain TF ops but not the other way around.
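+      // For example (hypothetical names): an unfused instruction %add.3 in
+      // computation "cluster_0" whose metadata op_name is "dense/add" ends up
+      // with the node name "cluster_0/dense/add/add.3" after the StrAppend
+      // calls below.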
+      StrAppend(&node_name, "/", instruction->metadata().op_name());
+    }
+  }
+  string instruction_name = instruction->name();
+  if (instruction->opcode() == HloOpcode::kParameter) {
+    StrAppend(&instruction_name, ".", instruction->parameter_number());
+  }
+  StrAppend(&node_name, "/", instruction_name);
+  CleanNodeName(&node_name);
+  auto ret =
+      instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
+  CHECK(ret.second);
+  return ret.first->second;
+}
+
+void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
+                                     NodeDef* node_def) const {
+  auto& attrs = *node_def->mutable_attr();
+
+  // Set the number of arguments for instructions that have variadic operands.
+  if (HloOpcodeIsVariadic(instruction->opcode())) {
+    tensorflow::AttrValue attr_value;
+    attr_value.set_i(instruction->operands().size());
+    attrs["arg_num"] = attr_value;
+  }
+
+  // Set the node type.
+  attrs["type"].set_s(
+      xla::PrimitiveType_Name(instruction->shape().element_type()));
+
+  // Set the framework op (e.g., the TensorFlow op) that generated this XLA op.
+  attrs["tf_op_type"].set_s(instruction->metadata().op_type());
+  attrs["tf_op_name"].set_s(instruction->metadata().op_name());
+
+  // Set the shape of the output tensor. "_output_shapes" is a special
+  // attribute name used by TensorBoard for shapes of output tensors.
+  tensorflow::AttrValue shapes;
+  *shapes.mutable_list()->add_shape() = GetTensorShape(instruction);
+  attrs["_output_shapes"] = shapes;
+
+  // Set the layout.
+  if (LayoutUtil::HasLayout(instruction->shape())) {
+    string layout_string;
+    if (ShapeUtil::IsTuple(instruction->shape())) {
+      // For tuples, emit the full shape because the layout of a tuple is not
+      // represented in a single Layout field.
+      layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
+    } else {
+      layout_string = StrCat(
+          "{", Join(instruction->shape().layout().minor_to_major(), ","), "}");
+    }
+    attrs["layout"].set_s(layout_string);
+  }
+
+  // Set op-specific attributes.
+  switch (instruction->opcode()) {
+    case HloOpcode::kConcatenate:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kReduce:
+    case HloOpcode::kReverse:
+    case HloOpcode::kTranspose:
+      for (auto dim : instruction->dimensions()) {
+        attrs["dims"].mutable_list()->add_i(dim);
+      }
+      break;
+    case HloOpcode::kGetTupleElement:
+      attrs["index"].set_i(instruction->tuple_index());
+      break;
+    case HloOpcode::kRng:
+      attrs["dist"].set_s(
+          RandomDistribution_Name(instruction->random_distribution()));
+      break;
+    case HloOpcode::kConstant:
+      if (ShapeUtil::IsScalar(instruction->shape())) {
+        attrs["value"].set_s(
+            LiteralUtil::GetAsString(instruction->literal(), {}));
+      }
+      break;
+    case HloOpcode::kCustomCall:
+      attrs["custom_call_target"].set_s(instruction->custom_call_target());
+      break;
+    default:
+      break;
+  }
+}
+
+Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
+  if (!visited_instructions_.insert(instruction).second) {
+    // Skip instructions that have already been added.
+    return Status::OK();
+  }
+
+  NodeDef* node_def = graph_def_.add_node();
+  node_def->set_name(GetNodeNameForInstruction(instruction));
+  node_def->set_op(GetOpDefName(instruction));
+  SetNodeAttrs(instruction, node_def);
+  if (instruction->opcode() == HloOpcode::kFusion) {
+    for (auto& fused_instruction : instruction->fused_instructions()) {
+      TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get()));
+    }
+  }
+  // Add all edges, including control edges.
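+  // In GraphDef form, a control input is the producer's node name prefixed
+  // with '^'; that convention is used below to encode called computations as
+  // control dependencies of the caller.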
+  for (unsigned i = 0; i < instruction->operands().size(); ++i) {
+    *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i));
+  }
+  // Called computations are control dependencies.
+  for (const auto* called_computation : instruction->called_computations()) {
+    *node_def->add_input() = StrCat(
+        "^", GetNodeNameForInstruction(called_computation->root_instruction()));
+  }
+  return Status::OK();
+}
+
+}  // namespace hlo_graph_dumper
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
new file mode 100644
index 00000000000..b2c578af912
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace xla {
+namespace hlo_graph_dumper {
+
+// This constructs a TensorFlow graph for HLO computations.
+class HloTfGraphBuilder {
+ public:
+  // Adds a computation to the graph.
+  Status AddComputation(const HloComputation& computation);
+
+  const tensorflow::GraphDef& GetGraphDef() const;
+
+ private:
+  // Gets the node name of an instruction. The node name is hierarchical. For
+  // example, if an instruction is fused, it will be put in a subgraph of the
+  // fusion instruction.
+  const string& GetNodeNameForInstruction(const HloInstruction* instruction);
+
+  void SetNodeAttrs(const HloInstruction* instruction,
+                    tensorflow::NodeDef* node_def) const;
+
+  Status AddInstruction(const HloInstruction* instruction);
+
+  tensorflow::GraphDef graph_def_;
+  // This records instructions that have been visited.
+  std::unordered_set<const HloInstruction*> visited_instructions_;
+  // A cache that maps an instruction to its node name.
+  std::unordered_map<const HloInstruction*, string> instruction_to_node_name_;
+};
+
+// Cleans the node name to make it a valid name in a TensorFlow graph.
+void CleanNodeName(string* name);
+
+}  // namespace hlo_graph_dumper
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
new file mode 100644
index 00000000000..c2718ea8003
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -0,0 +1,188 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace xla {
+namespace hlo_graph_dumper {
+namespace {
+
+using ::tensorflow::GraphDef;
+
+class HloTfGraphBuilderTest : public HloTestBase {
+ protected:
+  HloTfGraphBuilderTest() {}
+  HloTfGraphBuilder generator_;
+
+  // Creates a computation which takes a scalar and returns its negation.
+  std::unique_ptr<HloComputation> CreateNegateComputation() {
+    auto builder = HloComputation::Builder("Negate");
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, r0f32_, "param0"));
+    builder.AddInstruction(
+        HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
+    return builder.Build();
+  }
+
+  // Creates a computation which calls map with the given computation.
+  std::unique_ptr<HloComputation> CreateMapComputation(
+      HloComputation *map_computation) {
+    auto builder = HloComputation::Builder("Map");
+    auto param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, r0f32_, "param0"));
+    builder.AddInstruction(
+        HloInstruction::CreateMap(r0f32_, {param}, map_computation));
+    return builder.Build();
+  }
+  Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
+};
+
+static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
+                                                const string &attr_name) {
+  auto attr = node.attr().find(attr_name);
+  CHECK(attr != node.attr().end());
+  return attr->second;
+}
+
+TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
+  auto builder = HloComputation::Builder("Concatenate");
+  Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
+  auto param_1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param0"));
+  auto param_2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape, "param1"));
+  builder.AddInstruction(HloInstruction::CreateConcatenate(
+      ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1));
+  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+  GraphDef graph_def = generator_.GetGraphDef();
+  EXPECT_EQ(graph_def.node_size(), 3);
+  const auto &node = graph_def.node(2);
+  EXPECT_EQ(node.name(), "Concatenate/concatenate");
+
+  // Check dimensions.
+  auto dims_value = GetNodeAttr(node, "dims");
+  EXPECT_EQ(dims_value.list().i_size(), 1);
+  EXPECT_EQ(dims_value.list().i(0), 1);
+
+  // Check shapes.
+  auto shape_value = GetNodeAttr(node, "_output_shapes");
+  EXPECT_EQ(shape_value.list().shape_size(), 1);
+  EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
+  EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
+  EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
+}
+
+TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
+  auto builder = HloComputation::Builder("Const");
+  HloInstruction *instruction = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
+  OpMetadata metadata;
+  metadata.set_op_name("x");
+  metadata.set_op_type("y");
+  instruction->set_metadata(metadata);
+  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+  GraphDef graph_def = generator_.GetGraphDef();
+  EXPECT_EQ(graph_def.node_size(), 1);
+  const auto &node = graph_def.node(0);
+  EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
+  EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
+  EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
+  EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
+}
+
+TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
+  auto negate_computation = CreateNegateComputation();
+  TF_CHECK_OK(generator_.AddComputation(*negate_computation));
+  GraphDef graph_def = generator_.GetGraphDef();
+  EXPECT_EQ(graph_def.node_size(), 2);
+  EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0");
+  EXPECT_EQ(graph_def.node(0).op(), "HloParameter");
+  EXPECT_EQ(graph_def.node(1).name(), "Negate/negate");
+  EXPECT_EQ(graph_def.node(1).op(), "HloNegate");
+  EXPECT_EQ(graph_def.node(1).input_size(), 1);
+  EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0");
+}
+
+TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) {
+  auto builder = HloComputation::Builder("GE");
+  auto param_1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32_, "param0"));
+  auto param_2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32_, "param1"));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
+  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+  GraphDef graph_def = generator_.GetGraphDef();
+  EXPECT_EQ(graph_def.node_size(), 3);
+  EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
+  EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
+  EXPECT_EQ(graph_def.node(2).input_size(), 2);
+  EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to");
+  EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
+}
+
+TEST_F(HloTfGraphBuilderTest, IncorporateTfOpsStructure) {
+  auto builder = HloComputation::Builder("GE");
+  auto param_1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32_, "param0"));
+  auto param_2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32_, "param1"));
+  auto ge = builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2));
+  OpMetadata metadata;
+  metadata.set_op_name("x/y");
+  metadata.set_op_type("Y");
+  ge->set_metadata(metadata);
+  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
+  GraphDef graph_def = generator_.GetGraphDef();
+  EXPECT_EQ(graph_def.node_size(), 3);
+  EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0");
+  EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1");
+  EXPECT_EQ(graph_def.node(2).input_size(), 2);
+  EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to");
+  EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo");
+}
+
+TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
+  // Create computations with a diamond-shaped callgraph.
+
+TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
+  // Create computations with a diamond-shaped callgraph.
+  auto negate_computation = CreateNegateComputation();
+  auto map1_computation = CreateMapComputation(negate_computation.get());
+  auto map2_computation = CreateMapComputation(negate_computation.get());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32_, "param0"));
+  auto map1 = builder.AddInstruction(
+      HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
+  auto map2 = builder.AddInstruction(
+      HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
+  auto computation = builder.Build();
+  TF_CHECK_OK(generator_.AddComputation(*computation));
+  EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
+}
+
+}  // namespace
+}  // namespace hlo_graph_dumper
+}  // namespace xla
+
+int main(int argc, char **argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
new file mode 100644
index 00000000000..de6081e57e7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+
+namespace xla {
+
+StatusOr<bool> HloVerifier::Run(HloModule* module) {
+  for (auto& computation : module->computations()) {
+    for (const auto& instruction : computation->instructions()) {
+      TF_RET_CHECK(instruction->parent() == computation.get());
+      if (instruction->opcode() == HloOpcode::kFusion) {
+        for (const auto& fused : instruction->fused_instructions()) {
+          TF_RET_CHECK(fused->parent() ==
+                       instruction->fused_instructions_computation())
+              << "Fused HLO was missing a parent: " << fused->ToString()
+              << " parent: " << fused->parent()
+              << " computation: " << computation.get();
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
+}  // namespace xla
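Since `HloVerifier` is a pure checking pass, it may help to see how a caller would run it. A hedged sketch, assuming XLA's usual status macros (`TF_ASSIGN_OR_RETURN`, `TF_RET_CHECK`); real pipelines wire passes through pass-pipeline plumbing instead:

```cpp
#include "tensorflow/compiler/xla/service/hlo_verifier.h"

namespace xla {

// Sketch only: verify a module a caller already owns.
Status VerifyModule(HloModule* module) {
  HloVerifier verifier;
  TF_ASSIGN_OR_RETURN(bool changed, verifier.Run(module));
  // A pure verification pass reports "no change" on success.
  TF_RET_CHECK(!changed);
  return Status::OK();
}

}  // namespace xla
```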
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
new file mode 100644
index 00000000000..5159420b3fb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// HLO pass that verifies invariants of HLO instructions for each computation
+// in the module.
+class HloVerifier : public HloPassInterface {
+ public:
+  ~HloVerifier() override = default;
+  tensorflow::StringPiece name() const override { return "verifier"; }
+
+  // Note: always returns false (no instructions are ever modified by this
+  // pass).
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 0054edcf6ab..2887a8a0a09 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -22,13 +22,16 @@ limitations under the License.
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
+namespace op = xla::testing::opcode_matchers;
+
 namespace xla {
 namespace {
 
@@ -56,14 +59,14 @@ TEST_F(InlinerTest, MapMax) {
       HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
 
   auto computation = builder.Build();
-  auto hlo_module = MakeUnique<HloModule>("test_module");
+  auto hlo_module = CreateNewModule();
   hlo_module->AddEmbeddedComputation(std::move(max_f32));
   hlo_module->AddEntryComputation(std::move(computation));
-  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+
   Inliner inliner;
   EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
-  root = hlo_module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kMaximum);
+  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
+              op::Maximum(lhs, rhs));
 
   // Verify execution on CPU.
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
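The hunk above swaps opcode-equality checks for the structural matchers in `hlo_matchers.h`. A sketch of how those matchers compose (the helper is illustrative; it assumes gmock's `EXPECT_THAT` is available via `test.h`):

```cpp
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"

namespace op = xla::testing::opcode_matchers;

// Illustrative only: matchers nest, so one EXPECT_THAT pins down the root
// opcode and its operand subtree at once, with a readable failure message.
void ExpectInlinedMax(xla::HloComputation* computation,
                      xla::HloInstruction* lhs, xla::HloInstruction* rhs) {
  EXPECT_THAT(computation->root_instruction(), op::Maximum(lhs, rhs));
}
```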
@@ -90,14 +93,14 @@ TEST_F(InlinerTest, MapConstant) {
       HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
 
   auto computation = builder.Build();
-  auto hlo_module = MakeUnique<HloModule>("test_module");
+  auto hlo_module = CreateNewModule();
   hlo_module->AddEmbeddedComputation(std::move(const2_f32));
   hlo_module->AddEntryComputation(std::move(computation));
   HloInstruction* root = hlo_module->entry_computation()->root_instruction();
   Inliner inliner;
   EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
   root = hlo_module->entry_computation()->root_instruction();
-  EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+  EXPECT_THAT(root, op::Broadcast(op::Constant()));
 
   // Verify execution on CPU.
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
@@ -107,3 +110,7 @@
 
 }  // namespace
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 42e33d53967..721640cdbd8 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -29,7 +29,8 @@ limitations under the License.
 
 namespace xla {
 
-bool IsExpensive(const HloInstruction& instruction) {
+/*static*/ bool InstructionFusion::IsExpensive(
+    const HloInstruction& instruction) {
   switch (instruction.opcode()) {
     // Cheap instructions.
     case HloOpcode::kAbs:
@@ -50,7 +51,7 @@ bool IsExpensive(const HloInstruction& instruction) {
     case HloOpcode::kGetTupleElement:
     case HloOpcode::kGt:
     case HloOpcode::kInfeed:
-    case HloOpcode::kOutfeed:
+    case HloOpcode::kIsFinite:
     case HloOpcode::kLe:
     case HloOpcode::kLogicalAnd:
     case HloOpcode::kLogicalNot:
@@ -61,6 +62,7 @@ bool IsExpensive(const HloInstruction& instruction) {
     case HloOpcode::kMultiply:
     case HloOpcode::kNe:
     case HloOpcode::kNegate:
+    case HloOpcode::kOutfeed:
     case HloOpcode::kPad:
     case HloOpcode::kReshape:
     case HloOpcode::kReverse:
@@ -100,12 +102,18 @@ bool IsExpensive(const HloInstruction& instruction) {
     case HloOpcode::kRecv:
       return true;
   }
+
+  return false;
 }
 
-bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) {
-  return !(producer->users().size() == 1 &&
-           producer->users().count(consumer) == 1);
+namespace {
+// Returns true if fusing producer into consumer would cause producer to be
+// duplicated. This is the case if producer has uses other than consumer.
+bool FusionWouldDuplicate(const HloInstruction& producer,
+                          const HloInstruction& consumer) {
+  return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
 }
+}  // namespace
 
 StatusOr<bool> InstructionFusion::Run(HloModule* module) {
   bool changed = false;
@@ -122,8 +130,54 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
       computation_->MakeInstructionPostOrder();
   std::vector<HloInstruction*> post_order(post_order_list.begin(),
                                           post_order_list.end());
+
+  std::set<HloInstruction*> all_consumers_fusable;
+  // Find which ops can be fused into all of their consumers. We would rather
+  // not fuse an op into only some of its users, as that offers no benefit in
+  // terms of memory bandwidth, but forces us to keep more live values around.
+  for (auto* hlo : post_order) {
+    auto user_fusable_into_hlo = [this, &hlo](HloInstruction* consumer) {
+      if (!consumer->IsFusable()) {
+        return false;
+      }
+      for (int operand_number = 0;
+           operand_number < consumer->operands().size(); ++operand_number) {
+        if (consumer->operand(operand_number) == hlo) {
+          if (!ShouldFuse(consumer, operand_number)) {
+            return false;
+          }
+        }
+      }
+      return true;
+    };
+
+    // An "effectively unary" operation is one that has one "large"
+    // input with the others being negligible in terms of memory usage.
+    // We use "has a smaller true rank than the output" as a heuristic
+    // for "negligible" memory usage.
+    auto effectively_unary = [](HloInstruction* hlo) {
+      if (hlo->operands().size() == 1) {
+        return true;
+      }
+      auto output_rank = ShapeUtil::TrueRank(hlo->shape());
+      return std::count_if(
+                 hlo->operands().begin(), hlo->operands().end(),
+                 [output_rank](HloInstruction* operand) {
+                   return ((operand->opcode() != HloOpcode::kBroadcast) &&
+                           ShapeUtil::TrueRank(operand->shape()) >=
+                               output_rank);
+                 }) <= 1;
+    };
+
+    if (effectively_unary(hlo) ||
+        std::all_of(hlo->users().begin(), hlo->users().end(),
+                    user_fusable_into_hlo)) {
+      all_consumers_fusable.insert(hlo);
+    }
+  }
+
   tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
-  for (int i = 0; i < post_order.size(); ++i) {
+  for (size_t i = 0; i < post_order.size(); ++i) {
     InsertOrDie(&post_order_index, post_order[i], i);
   }
 
@@ -208,6 +262,12 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
 
     for (int64 i : sorted_operand_numbers) {
       HloInstruction* operand = instruction->mutable_operand(i);
+
+      if (FusionWouldDuplicate(*operand, *instruction) &&
+          (all_consumers_fusable.count(operand) == 0)) {
+        continue;
+      }
+
       if (operand->IsFusable() && ShouldFuse(instruction, i)) {
         HloInstruction* fusion_instruction = Fuse(operand, instruction);
 
@@ -260,8 +320,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
                                    int64 operand_index) {
   HloInstruction* producer = consumer->mutable_operand(operand_index);
   // Cost condition: don't duplicate expensive instructions.
-  if (FusionWouldDuplicate(producer, consumer) &&
-      (IsExpensive(*producer) || !may_duplicate_)) {
+  if (FusionWouldDuplicate(*producer, *consumer) &&
+      (is_expensive_(*producer) || !may_duplicate_)) {
    return false;
   }
 
@@ -274,7 +334,7 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
   // Cost condition: not fuse (expensive producers) and (consumers who reuse
   // operand elements).
   if (consumer->ReusesOperandElements(operand_index) &&
-      IsExpensive(*producer)) {
+      is_expensive_(*producer)) {
     return false;
   }
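With the cost predicate now injected through the constructor (see the header diff that follows), a backend can tighten or relax the duplication policy without subclassing. A hedged sketch of a stricter policy; the illustrative `kDivide` rule is not taken from this patch, and the call sites in the tests below simply pass `InstructionFusion::IsExpensive` through:

```cpp
#include <memory>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"

// Example policy: anything the default deems expensive, plus divides, and
// never duplicate producers.
std::unique_ptr<xla::InstructionFusion> MakeConservativeFusion() {
  auto is_expensive = [](const xla::HloInstruction& hlo) {
    return hlo.opcode() == xla::HloOpcode::kDivide ||
           xla::InstructionFusion::IsExpensive(hlo);
  };
  return xla::MakeUnique<xla::InstructionFusion>(is_expensive,
                                                 /*may_duplicate=*/false);
}
```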
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index b8fd3dd4f37..a9f3723f2df 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,15 +24,6 @@ limitations under the License.
 
 namespace xla {
 
-// Returns true if the computation of the given instruction is significantly
-// more expensive than just writing all the values of the instructions' result
-// array. Expensive operations should not be duplicated.
-bool IsExpensive(const HloInstruction& instruction);
-
-// Returns true if fusing producer into consumer would cause producer to be
-// duplicated. This is the case if producer has uses other than consumer.
-bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
-
 // HLO pass which performs instruction fusion. Instructions are fused
 // "vertically", meaning producing instructions are fused into their consumers
 // with the intent that the loops which compute their values will be fused in
@@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
 // instructions to fuse.
 class InstructionFusion : public HloPassInterface {
  public:
-  explicit InstructionFusion(bool may_duplicate = true)
-      : may_duplicate_(may_duplicate) {}
-  ~InstructionFusion() override {}
+  explicit InstructionFusion(
+      std::function<bool(const HloInstruction& instruction)> is_expensive,
+      bool may_duplicate = true)
+      : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
+  ~InstructionFusion() override = default;
   tensorflow::StringPiece name() const override { return "fusion"; }
 
   // Run instruction fusion on the given computation. Returns whether the
   // computation was changed (instructions were fused).
   StatusOr<bool> Run(HloModule* module) override;
 
+  // Returns true if the computation of the given instruction is significantly
+  // more expensive than just writing all the values of the instruction's
+  // result array. Expensive operations will not be duplicated.
+  static bool IsExpensive(const HloInstruction& instruction);
+
  protected:
   // Returns whether the given producer instruction should be fused into the
   // given consumer instruction. producer is necessarily an operand of consumer.
@@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface {
  private:
   HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
 
+  // Used to determine if an HLO is expensive. Expensive operations will not be
+  // duplicated.
+  std::function<bool(const HloInstruction& instruction)> is_expensive_;
+
   // Returns whether we may duplicate an instruction if we want to fuse it.
   bool may_duplicate_;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 2e3742ed75f..a2e6c2ae00b 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -15,8 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { using InstructionFusionTest = HloTestBase; @@ -32,11 +35,13 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), exp1, {0})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); EXPECT_EQ(broadcast2, computation->root_instruction()); } @@ -51,12 +56,14 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), negate1, {0})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, @@ -69,12 +76,14 @@ TEST_F(InstructionFusionTest, HloInstruction* reshape2 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, @@ -87,12 +96,14 @@ TEST_F(InstructionFusionTest, HloInstruction* transpose2 = builder.AddInstruction( HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { @@ -102,11 +113,13 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, 
computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { @@ -116,11 +129,13 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { @@ -130,11 +145,82 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( - InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie()); + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { + HloComputation::Builder builder(TestName()); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary, computation->root_instruction()); + EXPECT_FALSE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, AllowUnaryDuplication) { + HloComputation::Builder builder(TestName()); + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); + HloInstruction* unary1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); + builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + HloInstruction* unary2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary2, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { + auto shape = ShapeUtil::MakeShape(F32, {16, 16}); + auto 
small_shape = ShapeUtil::MakeShape(F32, {16}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, small_shape, "0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); + HloInstruction* binary1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + HloInstruction* unary = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary, computation->root_instruction()); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); } } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a350acc4dae..e9e199226a6 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& out, } BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, - const LogicalBuffer& buffer) - : layout_(layout), buffer_(&buffer) { + const LogicalBuffer& buffer, + bool mandatory) + : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) { CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); } @@ -73,8 +74,9 @@ string BufferLayoutConstraint::ToString() const { OperandLayoutConstraint::OperandLayoutConstraint( const ShapeLayout& shape_layout, const HloInstruction* instruction, - int64 operand_no) - : shape_layout_(shape_layout), + int64 operand_no, bool mandatory) + : LayoutConstraint(mandatory), + shape_layout_(shape_layout), instruction_(instruction), operand_no_(operand_no) { CHECK(shape_layout_.LayoutIsSet()); @@ -100,7 +102,9 @@ LayoutConstraints::LayoutConstraints( : points_to_analysis_(points_to_analysis), computation_(computation) { // Gather all array-shaped logical buffers into unconstrained_buffer_ids. for (auto& buffer : points_to_analysis_.logical_buffers()) { - if (buffer->IsArray()) { + // The points to analysis is computed per module, restrict constraints to + // array buffers in this computation. 
+ if (buffer->IsArray() && buffer->instruction()->parent() == computation) { unconstrained_buffer_ids_.insert(buffer->id()); } } @@ -115,15 +119,17 @@ bool LayoutConstraints::OperandBufferForwarded( auto operand_buffers = points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) .CreateFlattenedSet(); - std::vector intersection; - std::set_intersection(output_buffers.begin(), output_buffers.end(), - operand_buffers.begin(), operand_buffers.end(), - std::back_inserter(intersection)); - return !intersection.empty(); + for (const LogicalBuffer* output_buffer : output_buffers) { + if (operand_buffers.count(output_buffer) > 0) { + return true; + } + } + return false; } Status LayoutConstraints::SetBufferLayout(const Layout& layout, - const LogicalBuffer& buffer) { + const LogicalBuffer& buffer, + bool mandatory) { VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout); @@ -138,26 +144,38 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const Layout* curr_layout = BufferLayout(buffer); - if (curr_layout != nullptr) { - if (!LayoutUtil::Equal(*curr_layout, layout)) { + const BufferLayoutConstraint* curr_constraint = + GetBufferLayoutConstraint(buffer); + if (curr_constraint != nullptr) { + if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + if (curr_constraint->mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(*curr_layout).c_str(), + LayoutUtil::HumanString(curr_constraint->layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - // New constraint matches existing constraint. Nothing to do. - return Status::OK(); } - auto new_constraint_it = buffer_constraints_.insert( - {&buffer, BufferLayoutConstraint(layout, buffer)}); - added_constraints_.push_back(&new_constraint_it.first->second); + auto iter = buffer_constraints_.find(&buffer); + bool overwrite = iter != buffer_constraints_.end(); + if (!overwrite) { + iter = buffer_constraints_ + .insert(std::make_pair( + &buffer, BufferLayoutConstraint(layout, buffer, mandatory))) + .first; + } else { + iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true); + } + added_constraints_.push_back(&iter->second); // Remove buffer from the set of unconstrained buffers. 
- TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == 1); + TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == + static_cast(!overwrite)); unconstrained_buffer_ids_.erase(buffer.id()); return Status::OK(); @@ -165,23 +183,27 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, - int64 operand_no) { + int64 operand_no, bool mandatory) { VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " << operand_no << " : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); - const ShapeLayout* curr_shape_layout = OperandLayout(instruction, operand_no); + const OperandLayoutConstraint* curr_shape_layout = + GetOperandLayoutConstraint(instruction, operand_no); if (curr_shape_layout != nullptr) { - if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + if (curr_shape_layout->shape_layout().MatchesLayoutInShape( + shape_with_layout)) { + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + if (curr_shape_layout->mandatory()) { return FailedPrecondition( "Operand %lld of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", operand_no, instruction->name().c_str(), - curr_shape_layout->ToString().c_str(), + curr_shape_layout->shape_layout().ToString().c_str(), ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); } - // New constraint matches existing constraint. Nothing to do. - return Status::OK(); } // If any buffers in the operand occur in the output of the instruction, then @@ -195,22 +217,31 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } auto key = std::make_pair(instruction, operand_no); - auto new_constraint_it = operand_constraints_.insert( - {key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, - operand_no)}); - added_constraints_.push_back(&new_constraint_it.first->second); + auto iter = operand_constraints_.find(key); + if (iter == operand_constraints_.end()) { + auto pair = std::make_pair( + key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), + instruction, operand_no, mandatory)); + iter = operand_constraints_.insert(pair).first; + } else { + iter->second = + OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, + operand_no, /*mandatory=*/true); + } + added_constraints_.push_back(&iter->second); return Status::OK(); } Status LayoutConstraints::SetArrayOperandLayout( - const Layout& layout, const HloInstruction* instruction, int64 operand_no) { + const Layout& layout, const HloInstruction* instruction, int64 operand_no, + bool mandatory) { const HloInstruction* operand = instruction->operand(operand_no); TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - return SetOperandLayout(shape, instruction, operand_no); + return SetOperandLayout(shape, instruction, operand_no, mandatory); } Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { @@ -252,7 +283,7 @@ Status LayoutConstraints::SetInstructionLayout( // Create a BufferLayoutConstraint for each array shape in the output of the // instruction. 
- return ShapeUtil::ForEachSubshape( + return ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, [this, instruction](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -273,15 +304,29 @@ Status LayoutConstraints::SetInstructionLayout( const Layout* LayoutConstraints::BufferLayout( const LogicalBuffer& buffer) const { + if (const auto* constraint = GetBufferLayoutConstraint(buffer)) { + return &constraint->layout(); + } + return nullptr; +} +const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint( + const LogicalBuffer& buffer) const { auto it = buffer_constraints_.find(&buffer); - return it == buffer_constraints_.end() ? nullptr : &it->second.layout(); + return it == buffer_constraints_.end() ? nullptr : &it->second; } const ShapeLayout* LayoutConstraints::OperandLayout( const HloInstruction* instruction, int64 operand_no) const { + if (const auto* constraint = + GetOperandLayoutConstraint(instruction, operand_no)) { + return &constraint->shape_layout(); + } + return nullptr; +} +const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint( + const HloInstruction* instruction, int64 operand_no) const { auto it = operand_constraints_.find(std::make_pair(instruction, operand_no)); - return it == operand_constraints_.end() ? nullptr - : &it->second.shape_layout(); + return it == operand_constraints_.end() ? nullptr : &it->second; } const ShapeLayout* LayoutConstraints::ResultLayout() const { @@ -298,8 +343,8 @@ string LayoutConstraints::ToString() const { for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { tensorflow::strings::StrAppend( - &output, " operand (", i, "): ", - OperandLayout(instruction, i)->ToString(), "\n"); + &output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : @@ -338,6 +383,12 @@ Status LayoutAssignment::AddMandatoryConstraints( // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. shape_with_layout = &instruction->shape(); + } else if (instruction->opcode() == HloOpcode::kOutfeed) { + // Constrain the input to the Outfeed instruction to be the expected + // layout of the Outfeed. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + instruction->outfeed_shape(), instruction.get(), 0, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. @@ -369,7 +420,7 @@ Status LayoutAssignment::AddMandatoryConstraints( for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( called_computation_layout.parameter_layout(i).shape(), - instruction.get(), i)); + instruction.get(), i, /*mandatory=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -420,7 +471,8 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( body_layout.result_shape(), instruction.get())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction.get(), 0)); + body_layout.result_shape(), instruction.get(), 0, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. // For now we only support row major layouts for all inputs and outputs. 
@@ -444,7 +496,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
       Shape row_major_operand_shape(row_major_shape(operand_shape));
 
       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-          row_major_operand_shape, instruction.get(), i));
+          row_major_operand_shape, instruction.get(), i, /*mandatory=*/true));
     }
   }
 }
@@ -566,11 +618,11 @@ Status CheckLayouts(
       // which could be the source of the subshape value.
       const PointsToSet& points_to_set =
           points_to_analysis->GetPointsToSet(instruction.get());
-      TF_RETURN_IF_ERROR(points_to_set.ForEachElement(
+      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](
-              ShapeIndex index, bool is_leaf,
+              ShapeIndex index,
               const std::vector<const LogicalBuffer*>& buffers) -> Status {
-            if (is_leaf) {
+            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
               const Shape& instruction_subshape =
                   ShapeUtil::GetSubshape(instruction->shape(), index);
               for (const LogicalBuffer* buffer : buffers) {
@@ -653,44 +705,6 @@ LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout)
   }
 }
 
-namespace {
-
-// Given a permutation of `{0, 1, ..., n}` `indices`, returns a permutation of
-// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the
-// indices `to_delete` wherever in `indices` they are, and inserting the indices
-// `to_insert` arbitrarily at the back.
-tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>
-DeleteAndInsertIndices(
-    std::vector<int64> to_delete, std::vector<int64> to_insert,
-    tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64> indices) {
-  std::sort(to_delete.begin(), to_delete.end(), std::greater<int64>());
-  std::sort(to_insert.begin(), to_insert.end(), std::less<int64>());
-  for (auto index : to_delete) {
-    auto i = indices.begin();
-    while (i != indices.end()) {
-      if (*i == index) {
-        i = indices.erase(i);
-      } else {
-        if (*i > index) {
-          (*i)--;
-        }
-        ++i;
-      }
-    }
-  }
-  for (auto index : to_insert) {
-    for (auto i = indices.begin(); i != indices.end(); ++i) {
-      if (*i >= index) {
-        (*i)++;
-      }
-    }
-    indices.Add(index);
-  }
-  return indices;
-}
-
-}  // namespace
-
 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
     const Layout& output_layout, const HloInstruction* instruction,
     int64 operand_no) {
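The next two hunks replace the delete-and-insert-indices trick with a two-step search: try the row-major layout first, then fall back to `ShapeUtil::AlignLayouts`. A condensed sketch of that decision order (hypothetical helper mirroring the added code; array shapes assumed, error handling elided):

```cpp
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Sketch only: choose an operand layout that makes a reshape a bitcast.
std::unique_ptr<Layout> PickBitcastFriendlyOperandLayout(
    const Shape& operand_shape, const Shape& output_shape_with_layout) {
  // Step 1: if the output layout is row major (monotonic dim-0 major), try
  // the row-major operand layout; a bitcast reshape then needs no copy.
  if (LayoutUtil::IsMonotonicWithDim0Major(
          output_shape_with_layout.layout())) {
    Shape row_major = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
        operand_shape.element_type(),
        AsInt64Slice(operand_shape.dimensions()));
    if (ShapeUtil::ReshapeIsBitcast(row_major, output_shape_with_layout)) {
      return MakeUnique<Layout>(row_major.layout());
    }
  }
  // Step 2: otherwise ask AlignLayouts for any operand layout that makes the
  // reshape a bitcast.
  auto aligned =
      ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
  if (aligned) {
    return MakeUnique<Layout>(aligned.value().layout());
  }
  return nullptr;  // No bitcast-friendly layout; the caller falls back.
}

}  // namespace xla
```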
@@ -713,21 +727,32 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
   }
 
   if (instruction->opcode() == HloOpcode::kReshape) {
-    // Pick the operand layout that makes the reshape a bitcast. If the reshape
-    // only inserts or deletes degenerate dimensions, we can easily compute the
-    // desired layout by accordingly inserting and deleting the elements in the
-    // minor-to-major list.
-    bool merely_inserts_or_deletes_1_sized_dims;
-    std::vector<int64> inserted_indices, deleted_indices;
-    std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
-             inserted_indices) =
-        instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
-    if (merely_inserts_or_deletes_1_sized_dims) {
-      Layout operand_layout = LayoutUtil::MakeLayout(
-          AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices,
-                                              output_layout.minor_to_major())));
+    // Prefer the operand layout that makes the reshape a bitcast. If any
+    // dimension bound is 1 in the operand shape, there may be several such
+    // layouts. So if 'output_layout' is a MajorToMinor layout, check whether
+    // the reshape is a bitcast when using the same layout. This may avoid
+    // copy operations.
+    const Shape& output_shape = instruction->shape();
+    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
+        AsInt64Slice(output_layout.minor_to_major()));
+    const Shape& operand_shape = operand->shape();
+    if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) {
+      Shape operand_shape_with_layout =
+          ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+              operand_shape.element_type(),
+              AsInt64Slice(operand_shape.dimensions()));
+      if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout,
+                                      output_shape_with_layout)) {
+        return MakeUnique<Layout>(operand_shape_with_layout.layout());
+      }
+    }
+    auto aligned_operand_shape =
+        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
+    if (aligned_operand_shape) {
+      auto operand_layout = aligned_operand_shape.value().layout();
       TF_CHECK_OK(
-          LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
+          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
       return MakeUnique<Layout>(operand_layout);
     }
   }
@@ -762,18 +787,32 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
   }
 
   if (user->opcode() == HloOpcode::kReshape) {
-    // Pick the user layout that makes the reshape a bitcast.
-    bool merely_inserts_or_deletes_1_sized_dims;
-    std::vector<int64> inserted_indices, deleted_indices;
-    std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
-             inserted_indices) =
-        user->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
-    if (merely_inserts_or_deletes_1_sized_dims) {
-      Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice(
-          DeleteAndInsertIndices(deleted_indices, inserted_indices,
-                                 operand_layout.minor_to_major())));
+    // Prefer the user layout that makes the reshape a bitcast. If any
+    // dimension bound is 1 in the user shape, there may be several such
+    // layouts. So if 'operand_layout' is a MajorToMinor layout, check whether
+    // the reshape is a bitcast when using the same layout. This may avoid
+    // copy operations.
+    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+        operand->shape().element_type(),
+        AsInt64Slice(operand->shape().dimensions()),
+        AsInt64Slice(operand_layout.minor_to_major()));
+    const Shape& output_shape = user->shape();
+    if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) {
+      Shape output_shape_with_layout =
+          ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+              output_shape.element_type(),
+              AsInt64Slice(output_shape.dimensions()));
+      if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout,
+                                      operand_shape_with_layout)) {
+        return MakeUnique<Layout>(output_shape_with_layout.layout());
+      }
+    }
+    auto aligned_user_shape =
+        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
+    if (aligned_user_shape) {
+      auto user_layout = aligned_user_shape.value().layout();
       TF_CHECK_OK(
-          LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
       return MakeUnique<Layout>(user_layout);
     }
   }
@@ -877,11 +916,11 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
   // match the given layout.
const PointsToSet& points_to_set = constraints->points_to_analysis().GetPointsToSet(instruction); - return points_to_set.ForEachElement( + return points_to_set.ForEachElementWithStatus( [this, &shape_layout, constraints]( - const ShapeIndex& index, bool is_leaf, + const ShapeIndex& index, const std::vector& buffers) -> Status { - if (is_leaf) { + if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { for (const LogicalBuffer* buffer : buffers) { if (constraints->BufferLayout(*buffer) == nullptr && ShapeUtil::IsArray(buffer->shape())) { @@ -930,7 +969,8 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.shape_layout().layout(), user, operand_constraint.operand_no()); if (layout != nullptr) { - TF_RETURN_IF_ERROR(constraints->SetBufferLayout(*layout, *buffer)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false)); } } return Status::OK(); @@ -960,11 +1000,19 @@ Status LayoutAssignment::PropagateBufferConstraint( instruction, operand_no); if (operand_layout != nullptr) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - *operand_layout, instruction, operand_no)); + *operand_layout, instruction, operand_no, /*mandatory=*/true)); } } } } + return PropagateBufferConstraintToUses(buffer_constraint, constraints); +} + +Status LayoutAssignment::PropagateBufferConstraintToUses( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + const LogicalBuffer& buffer = buffer_constraint.buffer(); + TF_RET_CHECK(buffer.IsArray()); // Propagate the layout to all array uses of the logical buffer. This skips // uses of the buffer where the buffer is the element of a tuple. @@ -977,7 +1025,7 @@ Status LayoutAssignment::PropagateBufferConstraint( if (constraints->OperandLayout(user, operand_no) == nullptr && !constraints->OperandBufferForwarded(user, operand_no)) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - buffer_constraint.layout(), user, operand_no)); + buffer_constraint.layout(), user, operand_no, /*mandatory=*/false)); } } @@ -1034,7 +1082,7 @@ StatusOr InferArrayLayout( *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. This case is possible in valid XLA - // computations because we do not propagate BufferLayoutConstaints to all + // computations because we do not propagate BufferLayoutConstraints to all // LogicalBuffers which may alias the constrained LogicalBuffer at some // point in the computation. return FailedPrecondition( @@ -1197,7 +1245,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Any remaining layouts in the output of the instruction must be // inferrable using points-to analysis. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshape( + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( instruction->mutable_shape(), [instruction, &constraints](Shape* subshape, const ShapeIndex& index) { if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) { @@ -1217,6 +1265,9 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RETURN_IF_ERROR(SetFusionLayouts(instruction)); } + // Execute extra verification step once the layout has been finalized. + TF_RETURN_IF_ERROR(Verify(instruction)); + // Verify all layouts in the shape have been set. 
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } @@ -1247,7 +1298,7 @@ Status LayoutAssignment::RunOnComputation( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(computation->parent())); - // Construct LayoutConstaints with all layout constraints of the computation. + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(*points_to_analysis, computation); // Add constraints required for correctness on all backends (eg, entry @@ -1272,7 +1323,8 @@ Status LayoutAssignment::RunOnComputation( const LogicalBuffer& buffer = points_to_analysis->GetBuffer( *constraints.unconstrained_buffer_ids().begin()); TF_RETURN_IF_ERROR(constraints.SetBufferLayout( - LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer)); + LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer, + /*mandatory=*/false)); TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 61dc7b12075..ccfc17da4c4 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -46,10 +46,16 @@ namespace xla { // gathered together in LayoutConstraints object. class LayoutConstraint { public: - LayoutConstraint() = default; + LayoutConstraint(bool mandatory) : mandatory_(mandatory) {} virtual ~LayoutConstraint() = default; virtual string ToString() const = 0; + + // True if this constraint cannot be overwritten by a different constraint. + bool mandatory() const { return mandatory_; } + + private: + bool mandatory_; }; std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); @@ -58,7 +64,8 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); // array produced by a particular instruction. 
class BufferLayoutConstraint : public LayoutConstraint { public: - BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer); + BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, + bool mandatory); const LogicalBuffer& buffer() const { return *buffer_; } const Layout& layout() const { return layout_; } @@ -66,7 +73,7 @@ class BufferLayoutConstraint : public LayoutConstraint { string ToString() const override; private: - const Layout layout_; + Layout layout_; const LogicalBuffer* buffer_; }; @@ -78,7 +85,8 @@ class BufferLayoutConstraint : public LayoutConstraint { class OperandLayoutConstraint : public LayoutConstraint { public: OperandLayoutConstraint(const ShapeLayout& shape_layout, - const HloInstruction* instruction, int64 operand_no); + const HloInstruction* instruction, int64 operand_no, + bool mandatory); const ShapeLayout& shape_layout() const { return shape_layout_; } const HloInstruction* instruction() const { return instruction_; } @@ -90,7 +98,7 @@ class OperandLayoutConstraint : public LayoutConstraint { string ToString() const override; private: - const ShapeLayout shape_layout_; + ShapeLayout shape_layout_; const HloInstruction* instruction_; int64 operand_no_; }; @@ -99,7 +107,7 @@ class OperandLayoutConstraint : public LayoutConstraint { class ResultLayoutConstraint : public LayoutConstraint { public: explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) - : shape_layout_(shape_layout) {} + : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {} const ShapeLayout& shape_layout() const { return shape_layout_; } string ToString() const override; @@ -124,8 +132,7 @@ class LayoutConstraints { // Return a vector containing the constraints which have been added to the // LayoutConstraints object since the construction of the object or since the // last time ConsumeAddedConstraints() has been called. This is used to - // identify - // newly added constraints when propagating layouts. + // identify newly added constraints when propagating layouts. std::vector ConsumeAddedConstraints() { std::vector ret_vec(std::move(added_constraints_)); added_constraints_.clear(); @@ -137,23 +144,29 @@ class LayoutConstraints { // instruction, or the layout of the result of the computation, respectively, // if it has been constrained. Otherwise return nullptr. const Layout* BufferLayout(const LogicalBuffer& buffer) const; + const BufferLayoutConstraint* GetBufferLayoutConstraint( + const LogicalBuffer& buffer) const; const ShapeLayout* OperandLayout(const HloInstruction* instruction, int64 operand_no) const; + const OperandLayoutConstraint* GetOperandLayoutConstraint( + const HloInstruction* instruction, int64 operand_no) const; const ShapeLayout* ResultLayout() const; // Add a constraint on the layout of a LogicalBuffer, the layout of the // operand of the instruction, or the layout of the result of the computation, // respectively. - Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer); + Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, + bool mandatory = true); Status SetOperandLayout(const Shape& shape_with_layout, - const HloInstruction* instruction, int64 operand_no); + const HloInstruction* instruction, int64 operand_no, + bool mandatory = true); Status SetResultLayout(const Shape& shape_with_layout); // Convenience wrapper around SetOperandLayout for setting the layout of a // operand using a Layout object. The operand must be array-shaped. 
Status SetArrayOperandLayout(const Layout& layout, const HloInstruction* instruction, - int64 operand_no); + int64 operand_no, bool mandatory = true); // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers // created by the instruction to the layouts in the given shape. The @@ -233,6 +246,39 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + // Called after layouts of an instruction have been finalized to allow + // subclasses to check for platform specific assumptions. + virtual Status Verify(const HloInstruction* instruction) { + return Status::OK(); + } + + // Propagates a buffer layout constraint into the operands that use it. + Status PropagateBufferConstraintToUses( + const BufferLayoutConstraint& layout_constraint, + LayoutConstraints* constraints); + + // Propagates a layout constraint on the use of the result of the given + // instruction to the definitions of the LogicalBuffers which make up the + // result. + Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, + const HloInstruction* instruction, + LayoutConstraints* constraints); + + // Chooses a layout of operand `operand_no` of `instruction` that minimizes + // the cost of `instruction`. `output_layout` is the layout of `instruction`. + // Returns null if it can't decide the best layout. + // Precondition: `instruction` and the operand are array-shaped. + std::unique_ptr ChooseOperandLayoutFromOutputLayout( + const Layout& output_layout, const HloInstruction* instruction, + int64 operand_no); + // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of + // `user` that minimizes its cost on that operand. Returns null if it can't + // decide the best layout. + // Precondition: `user` and the operand are array-shaped. + std::unique_ptr ChooseOutputLayoutFromOperandLayout( + const Layout& operand_layout, const HloInstruction* user, + int64 operand_no); + private: // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. @@ -267,28 +313,6 @@ class LayoutAssignment : public HloPassInterface { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); - // Propagates a layout constraint on the use of the result of the given - // instruction to the definitions of the LogicalBuffers which make up the - // result. - Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, - const HloInstruction* instruction, - LayoutConstraints* constraints); - - // Chooses a layout of operand `operand_no` of `instruction` that minimizes - // the cost of `instruction`. `output_layout` is the layout of `instruction`. - // Returns null if it can't decide the best layout. - // Precondition: `instruction` and the operand are array-shaped. - std::unique_ptr ChooseOperandLayoutFromOutputLayout( - const Layout& output_layout, const HloInstruction* instruction, - int64 operand_no); - // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of - // `user` that minimizes its cost on that operand. Returns null if it can't - // decide the best layout. - // Precondition: `user` and the operand are array-shaped. 
- std::unique_ptr ChooseOutputLayoutFromOperandLayout( - const Layout& operand_layout, const HloInstruction* user, - int64 operand_no); - ComputationLayout* entry_computation_layout_; // Map containing the layouts of all computations assigned so diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6361907b0e4..6d818cdea0c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -26,10 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -38,9 +40,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::ElementsAre; + class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, @@ -63,8 +69,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); Layout layout = LayoutUtil::MakeLayout(minor_to_major); Shape shape(ashape); @@ -75,7 +81,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -93,8 +99,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); Shape col_major_shape(ashape); @@ -111,7 +117,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); 
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -142,8 +148,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); @@ -156,7 +162,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -191,13 +197,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { auto negate = builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kNegate, get_element0)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module.entry_computation()->ComputeProgramShape()); + module->entry_computation()->ComputeProgramShape()); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); EXPECT_FALSE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -229,17 +235,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module.entry_computation()->ComputeProgramShape()); + module->entry_computation()->ComputeProgramShape()); Shape result_shape = ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -264,11 +270,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { auto nested_tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, inner_tuple})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module.entry_computation()->ComputeProgramShape()); + module->entry_computation()->ComputeProgramShape()); Shape result_shape = nested_tuple->shape(); *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); @@ -278,7 +284,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); // Layout assignment 
should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -294,9 +300,9 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(&module) + .Run(module.get()) .ValueOrDie()); - HloInstruction* root = module.entry_computation()->root_instruction(); + HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), @@ -304,18 +310,16 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}), root->operand(1)->shape())); - // Verify some of the structure of the HLO graph. - EXPECT_EQ(constant, root->operand(0)->operand(0)); - EXPECT_EQ(HloOpcode::kCopy, root->operand(1)->operand(0)->opcode()); - EXPECT_EQ(HloOpcode::kConstant, - root->operand(1)->operand(0)->operand(0)->opcode()); + // Verify the structure of the HLO graph. + EXPECT_THAT(root, + op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); } TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { // param -> log -> reshape -> tanh auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); - Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3}); + Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2}); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ashape, "param")); auto log = builder.AddInstruction( @@ -325,28 +329,29 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build(tanh)); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); - *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(log_minor_to_major, 1), + EXPECT_GT(PositionInContainer(log_minor_to_major, 1), PositionInContainer(log_minor_to_major, 2)); auto reshape_minor_to_major = AsInt64Slice(reshape->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0), + EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0), PositionInContainer(reshape_minor_to_major, 2)); } @@ -366,8 +371,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { HloInstruction::CreateTranspose(bshape, log, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, 
transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build(tanh)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -378,7 +383,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -402,9 +407,9 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { HloInstruction::CreateBroadcast(bshape, param, {1, 2})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* computation = - module.AddEntryComputation(builder.Build(transpose)); + module->AddEntryComputation(builder.Build(transpose)); Shape input_shape_with_layout(ashape); Shape output_shape_with_layout(cshape); @@ -417,10 +422,10 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); - EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice<int64>{0, 1, 2})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), + ElementsAre(0, 1, 2)); } TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { @@ -451,9 +456,9 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { HloInstruction::CreateBroadcast(f32_234, tanh, {2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation* computation = - module.AddEntryComputation(builder.Build(tuple)); + module->AddEntryComputation(builder.Build(tuple)); ComputationLayout computation_layout(computation->ComputeProgramShape()); Shape param_shape_with_layout(f32_4); @@ -470,17 +475,86 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(&module, &computation_layout); + AssignLayouts(module.get(), &computation_layout); - EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice<int64>{0, 1})); - EXPECT_TRUE(ContainersEqual(transpose->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice<int64>{1, 0})); - EXPECT_TRUE(ContainersEqual(tanh->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice<int64>{0, 1})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); + EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); + EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1)); } -// Add test which fails due to copy tuple.
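+// A test-only LayoutAssignment subclass: PropagateBufferConstraint is +// overridden so that every operand of the same rank as the output is forced +// to take the layout of the output buffer, letting the test below observe a +// mandatory operand constraint being satisfied through an inserted copy.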
+class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { + public: + explicit OperandsMustBeTheSameLayoutAssignment( + ComputationLayout* entry_computation_layout) + : LayoutAssignment(entry_computation_layout) {} + + protected: + Status PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) override { + const LogicalBuffer& buffer = buffer_constraint.buffer(); + const HloInstruction* instruction = buffer.instruction(); + + // Force the operands' layout to the output layout. + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + if (ShapeUtil::Rank(instruction->shape()) != + ShapeUtil::Rank(operand->shape())) { + continue; + } + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + buffer_constraint.layout(), instruction, operand_no, + /*mandatory=*/true)); + } + return PropagateBufferConstraintToUses(buffer_constraint, constraints); + } +}; + +TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { + // param0 -> concatenate -> reshape + // param1 -^ + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {50, 1}); + Shape bshape = ShapeUtil::MakeShape(F32, {50, 2}); + Shape cshape = ShapeUtil::MakeShape(F32, {100}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "param")); + auto concatenate = builder.AddInstruction( + HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(cshape, concatenate)); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(reshape)); + + Shape param0_shape_with_layout(ashape); + Shape param1_shape_with_layout(ashape); + *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(param0_shape_with_layout); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(param1_shape_with_layout); + OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); + EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + + EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); + EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(concatenate->shape().layout().minor_to_major(), + ElementsAre(1, 0)); +} } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc new file mode 100644 index 00000000000..682bf19807b --- /dev/null +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -0,0 +1,225 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/liveness_util.h" + +#include <algorithm> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()); +} + +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Returns false if any uses are detected at 'index'; returns true otherwise. + const LogicalBuffer* buffer = + points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user, points_to_analysis)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +namespace { + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. +// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'.
+std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis) { + std::vector<std::pair<HloInstruction*, int64>> uses; + const std::vector<const LogicalBuffer*>& points_to = + points_to_analysis.GetPointsToSet(instruction).element(index); + for (const LogicalBuffer* buffer : points_to) { + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user, points_to_analysis)) { + continue; + } + for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { + uses.emplace_back(alias_user, op_idx); + } + } + } + } + return uses; +} + +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' is a kFusion instruction. +bool HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = GetAllUsesOfInstructionAtIndex( + fused_param, operand_index, points_to_analysis); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at operand index + // 'use_operand_index'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + +} // namespace + +// User and operand can share buffers iff both instructions emit the same shape +// and layout, and 'user' meets one of the following qualifications: +// *) Is element-wise. Or... +// *) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... +// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. +bool CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + Shape operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout.
+ if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, + points_to_analysis); + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is either kDot, or nested + // kFusion of kind kTransposeDot. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kDot || + (operand->opcode() == HloOpcode::kFusion && + operand->fusion_kind() == + HloInstruction::FusionKind::kTransposeDot); + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index, + points_to_analysis); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector<int64> operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h new file mode 100644 index 00000000000..0b01223db73 --- /dev/null +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A collection of utilities on the HLO graph.
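+// Currently: predicates used by buffer liveness and buffer assignment to +// decide whether an instruction reads one of its operands' buffers, and +// whether it may share a buffer with its user.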
+ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Returns true if 'user' cannot possibly use the buffer at 'index' in +// 'operand'. Returns false otherwise. +// +// REQUIRES: 'operand' is an operand of 'user'. +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, + const TuplePointsToAnalysis& points_to_analysis); + +// Overload which does not require points-to analysis. The result is more +// conservative (returns false more often). +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user); + +// Returns true if 'user' (at 'user_index') can share a buffer with its operand +// 'operand' (at 'operand_index'). +// Returns false otherwise. +// +// REQUIRES: 'operand' is an operand of 'user'. +bool CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index, + const TuplePointsToAnalysis& points_to_analysis); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc new file mode 100644 index 00000000000..bad4be149a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -0,0 +1,372 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/liveness_util.h" + +#include <memory> + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr<HloComputation> computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr<HloModule> module_; + HloComputation* computation_ = nullptr; + std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); + EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); + EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); + EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1.
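+ // (Only gte1 was fused, so the fusion reads tuple element 1 through its + // fused parameter; element 0 is consumed by the unfused gte0.)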
+ EXPECT_TRUE( + DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); + EXPECT_FALSE( + DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, + *points_to_analysis_)); + EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, + *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'.
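+ // ('add_operand' feeds the kAdd fused root at the operand index not taken + // by the kDot, which is the pattern CanShareOperandBufferWithUser accepts + // for kOutput fusion.)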
+ EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); + auto b_t = builder.AddInstruction( + HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + + auto nested_fusion = computation_->CreateFusionInstruction( + {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion = computation_->CreateFusionInstruction( + {add, nested_fusion}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused transpose-dot-add should be able to share buffer with + // 'add_operand'. + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
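+ // (The operand reaches the kAdd fused root through a kReverse rather than + // a kDot or a kTransposeDot fusion, so the output-fusion sharing rule does + // not apply.)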
+ EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 1edfec4dae5..12b2762f0ed 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,8 +48,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", - "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -80,6 +78,7 @@ cc_library( deps = [ ":ir_array", ":llvm_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/llvm_ir/README.md b/tensorflow/compiler/xla/service/llvm_ir/README.md index 9fe7152477f..9e4cdd45dca 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/README.md +++ b/tensorflow/compiler/xla/service/llvm_ir/README.md @@ -1,2 +1,2 @@ -Common utilites and abstractions for handling and emitting LLVM IR for XLA +Common utilities and abstractions for handling and emitting LLVM IR for XLA backends. diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index a552ea0218a..02710ff57f6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -26,35 +26,41 @@ limitations under the License. namespace xla { namespace llvm_ir { +// Sentry allocation used to represent parameters of the entry computation in +// alias_scope_metadata_ and noalias_metadata_. 
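+// The sentry's fields are placeholders (index -1, size 0); it exists only so +// that all entry parameters map to one distinct key in those maps.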
+static const BufferAllocation* kParameterAllocation = new BufferAllocation( + /*index=*/-1, /*size=*/0, /*is_thread_local=*/false, /*is_reusable=*/false, + LogicalBuffer::Color(0)); + void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm_ir::IrArray* array) { - BufferAllocation::Index buffer_index; + BufferAllocation::Slice buffer_slice; if (hlo.opcode() == HloOpcode::kParameter) { // Parameters may alias with each other but may not alias with our temporary // buffers. - buffer_index = kParameterAliasSet; + buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { - const std::set<BufferAllocation> allocations = - assignment_.GetAllocations(&hlo, /*index=*/{}); - if (allocations.empty() || allocations.size() > 1) { - // Skip HLOs which don't have buffers a buffer assigned or for which the + const std::set<BufferAllocation::Slice> slices = + assignment_.GetAllSlices(&hlo, /*index=*/{}); + if (slices.empty() || slices.size() > 1) { + // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their // aliasing properties in these cases. return; } - buffer_index = allocations.begin()->index(); + buffer_slice = *slices.begin(); } - llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_index]; + llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_slice]; if (alias_scope_md == nullptr) { alias_scope_md = - GetAliasScopeMetadataForBuffer(buffer_index, GetAliasDomain()); + GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain()); } array->AddAliasScopeMetadata(alias_scope_md); - llvm::MDNode*& noalias_md = noalias_metadata_[buffer_index]; + llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; if (noalias_md == nullptr) { - noalias_md = GetNoaliasMetadataForBuffer(buffer_index, GetAliasDomain(), + noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), assignment_, hlo); } array->AddNoaliasMetadata(noalias_md); @@ -80,7 +86,7 @@ llvm::MDNode* AliasAnalysis::GetAliasDomain() { } llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( - BufferAllocation::Index buffer_index, llvm::MDNode* domain) { + const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain) { legacy_flags::AliasAnalysisFlags* flags = legacy_flags::GetAliasAnalysisFlags(); if (!flags->xla_emit_alias_scope) { @@ -89,20 +95,19 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( // While we could synthesize an alias.scope, doing so is not more profitable // than LLVM's default behavior.
- if (buffer_index == kParameterAliasSet) { + if (buffer_slice.allocation() == kParameterAllocation) { return nullptr; } llvm::MDBuilder metadata_builder(domain->getContext()); llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef(tensorflow::strings::StrCat("buffer: ", buffer_index)), - domain); + AsStringRef("buffer: " + buffer_slice.ToString()), domain); llvm::MDNode* scope_list = llvm::MDNode::get(domain->getContext(), scope); return scope_list; } llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( - BufferAllocation::Index buffer_index, llvm::MDNode* domain, + const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain, const BufferAssignment& assignment, const HloInstruction& hlo) { legacy_flags::AliasAnalysisFlags* flags = legacy_flags::GetAliasAnalysisFlags(); @@ -147,18 +152,20 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - std::unordered_set<BufferAllocation::Index> buffers; + tensorflow::gtl::FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher> + buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. if (!assignment.HasAllocation(*buffer) || buffer->instruction()->opcode() == HloOpcode::kParameter) { continue; } - BufferAllocation::Index noalias_index = - assignment.GetAssignedAllocation(*buffer).index(); - // Our buffer must not noalias itself. - if (noalias_index != buffer_index) { - buffers.insert(noalias_index); + const BufferAllocation::Slice noalias_slice = + assignment.GetAssignedAllocation(*buffer).GetSlice(*buffer); + // Our buffer must not overlap with the noalias slice. + if (!buffer_slice.OverlapsWith(noalias_slice)) { + buffers.insert(noalias_slice); // Some instructions have too many operands, causing the noalias set to be // too large. To reduce compilation time (b/31901575), truncate noalias // sets to at most 500 elements. @@ -180,10 +187,9 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( llvm::MDBuilder metadata_builder(domain->getContext()); std::vector<llvm::Metadata*> scopes; - for (BufferAllocation::Index noalias_index : buffers) { + for (const BufferAllocation::Slice noalias_slice : buffers) { llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef(tensorflow::strings::StrCat("buffer: ", noalias_index)), - domain); + AsStringRef("buffer: " + noalias_slice.ToString()), domain); scopes.push_back(scope); } llvm::MDNode* noalias_list = diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index d8d45dd49b3..9eb1cbaa341 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ -#include <unordered_map> - #include "external/llvm/include/llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -44,20 +44,20 @@ class AliasAnalysis { // Returns a unique alias domain for this emitter. llvm::MDNode* GetAliasDomain(); - // Returns an alias.scope metadata node corresponding to a given buffer index.
+ // Returns an alias.scope metadata node corresponding to a given buffer slice. llvm::MDNode* GetAliasScopeMetadataForBuffer( - BufferAllocation::Index buffer_index, llvm::MDNode* domain); + const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain); - // Returns a noalias metadata node corresponding to a given buffer index. + // Returns a noalias metadata node corresponding to a given buffer slice. // - // |buffer_index| is the buffer index. + // |buffer_slice| is the buffer slice. // // |domain| corresponds to the alias scope domain as documented at // http://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata // // |hlo| is the instruction we are computing a noalias set for. llvm::MDNode* GetNoaliasMetadataForBuffer( - BufferAllocation::Index buffer_index, llvm::MDNode* domain, + const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain, const BufferAssignment& assignment, const HloInstruction& hlo); // The HLO module we are compiling for. @@ -73,18 +73,18 @@ class AliasAnalysis { // Holds the alias domain for this computation. llvm::MDNode* alias_domain_ = nullptr; - // Index in alias_scope_metadata_ and noalias_metadata_ for parameters - // of the entry computation which have special aliasing properties. - static constexpr int kParameterAliasSet = -1; - - // A map from a buffer index to metadata corresponding to its alias.scope + // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - std::unordered_map<BufferAllocation::Index, llvm::MDNode*> alias_scope_metadata_; + tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*, BufferAllocation::Slice::Hasher> + alias_scope_metadata_; - // A map from a buffer index to metadata corresponding to its noalias + // A map from a buffer slice to metadata corresponding to its noalias // metadata. - std::unordered_map<BufferAllocation::Index, llvm::MDNode*> noalias_metadata_; + tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*, BufferAllocation::Slice::Hasher> + noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 303bb3ee6b9..79007b7099a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -62,6 +62,13 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Returns the generator function for the given instruction. Generator GetGenerator(const HloInstruction* instruction) const; + // Returns the IR value for GetTupleElement instruction 'hlo'.
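+ // CHECK-fails if no value has been emitted for 'hlo'.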
+ llvm::Value* GetIrValueForGTE(const HloInstruction* hlo) const { + auto it = gte_values_.find(hlo); + CHECK(it != gte_values_.end()); + return it->second; + } + private: // Arrays of parameters of fusion instruction tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 38728d2e1f3..e401305ae73 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -236,10 +236,8 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::LoadInst* load = ir_builder->CreateLoad(element_address); llvm_ir::SetTbaaForInstruction(load, GetShape(), /*is_pointer_to=*/false); - for (const std::pair<int, llvm::MDNode*>& kind_md_pair : metadata_) { - int kind = kind_md_pair.first; - llvm::MDNode* md = kind_md_pair.second; - load->setMetadata(kind, md); + for (const auto& kind_md_pair : metadata_) { + load->setMetadata(kind_md_pair.first, kind_md_pair.second); } return load; } @@ -250,11 +248,9 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); llvm_ir::SetTbaaForInstruction(store, GetShape(), /*is_pointer_to=*/false); - for (const std::pair<int, llvm::MDNode*>& kind_md_pair : metadata_) { - int kind = kind_md_pair.first; - CHECK_NE(kind, llvm::LLVMContext::MD_invariant_load); - llvm::MDNode* md = kind_md_pair.second; - store->setMetadata(kind, md); + for (const auto& kind_md_pair : metadata_) { + CHECK_NE(kind_md_pair.first, llvm::LLVMContext::MD_invariant_load); + store->setMetadata(kind_md_pair.first, kind_md_pair.second); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 4ccded61e73..97f1b8ce308 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -22,6 +22,7 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Function.h" #include "external/llvm/include/llvm/IR/Instructions.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 0cc82b040d2..60ac0444bcd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -47,22 +47,22 @@ class ForLoop { // created exit basic block. Instructions before the insert point remain in // the insert BB: // - // /--------------\ /----------------\ + // +--------------+ +----------------+ // | insert BB | | insert BB | // | ... | | (preheader BB) | // | %foo = ... | | ... | // insert point ->| %bar = ... | ===> | %foo = ... | - // | ... | \----------------/ + // | ... | +----------------+ - // \--------------/ | + // +--------------+ | // V // [[ LOOP BBs ]] // | // V - // /--------------\ + // +--------------+ // | exit BB | // | %bar = ... | // | ... | - // \--------------/ + // +--------------+ // // `suffix` is a string used to disambiguate variable and basic block names // emitted in LLVM IR.
This string is appended to the name of the induction @@ -82,31 +82,31 @@ class ForLoop { // do_stuff(i); // } // - // /--------------\ + // +--------------+ // | preheader BB | // | i = 0 | - // \--------------/ + // +--------------+ // | // V - // /-------------\ + // +-------------+ // | header BB |<-+ // | if i < n: | | // | goto body | | // | else: | | // | goto exit | | - // \-------------/ | + // +-------------+ | // | | | // +--------+ | | // | V | - // | /-------------\ | + // | +-------------+ | // | | body BB | | // | | dostuff(i) |--+ // | | ++i | - // | \-------------/ + // | +-------------+ // | - // | /-------------\ + // | +-------------+ // +->| exit BB | - // \-------------/ + // +-------------+ // // Caller-emitted code to execute within the loop should be placed within the // "body" basic block. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 99d0d0e1c42..ff2f4cd693c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -449,25 +449,23 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { return ShapeUtil::ByteSizeOf(shape, pointer_size); } -llvm::FastMathFlags GetFastMathFlags(const HloModuleConfig& config) { +llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { llvm::FastMathFlags flags; - if (!config.fast_math_disabled()) { + if (fast_math_enabled) { // UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal. flags.setUnsafeAlgebra(); } return flags; } -void SetTargetOptions(const HloModuleConfig& config, +void SetTargetOptions(bool fast_math_enabled, llvm::TargetOptions* target_options) { - bool fast = !config.fast_math_disabled(); // In LLVM backend flags, UnsafeFPMath does not explicitly imply // NoInfs, etc. - target_options->UnsafeFPMath = fast; - target_options->NoInfsFPMath = fast; - target_options->NoNaNsFPMath = fast; - target_options->NoSignedZerosFPMath = fast; - target_options->LessPreciseFPMADOption = fast; + target_options->UnsafeFPMath = fast_math_enabled; + target_options->NoInfsFPMath = fast_math_enabled; + target_options->NoNaNsFPMath = fast_math_enabled; + target_options->NoSignedZerosFPMath = fast_math_enabled; } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 28488ca9991..7b09c1f8314 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,8 +27,7 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Value.h" #include "external/llvm/include/llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -130,7 +129,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, int alignment = 0); -// Creates a basic block with the same context and funtion as for the +// Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. 
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, @@ -219,11 +218,11 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout); // Gets an llvm::FastMathFlags that reflects the settings in the given // module config. -llvm::FastMathFlags GetFastMathFlags(const HloModuleConfig& config); +llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled); // Sets values in the given TargetOptions struct according to the given // compilation options. -void SetTargetOptions(const HloModuleConfig& config, +void SetTargetOptions(bool fast_math_enabled, llvm::TargetOptions* target_options); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 30bf450c5b1..131c2ee87b0 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -44,65 +46,6 @@ namespace se = ::perftools::gputools; namespace xla { -LocalExecuteOptions& LocalExecuteOptions::set_platform( - perftools::gputools::Platform* platform) { - platform_ = platform; - return *this; -} - -perftools::gputools::Platform* LocalExecuteOptions::platform() const { - return platform_; -} - -LocalExecuteOptions& LocalExecuteOptions::set_device_ordinal( - int device_ordinal) { - device_ordinal_ = device_ordinal; - return *this; -} - -int LocalExecuteOptions::device_ordinal() const { return device_ordinal_; } - -LocalExecuteOptions& LocalExecuteOptions::set_allocator( - DeviceMemoryAllocator* allocator) { - allocator_ = allocator; - return *this; -} - -DeviceMemoryAllocator* LocalExecuteOptions::allocator() const { - return allocator_; -} - -LocalExecuteOptions& LocalExecuteOptions::set_stream( - perftools::gputools::Stream* stream) { - stream_ = stream; - return *this; -} - -perftools::gputools::Stream* LocalExecuteOptions::stream() const { - return stream_; -} - -LocalExecuteOptions& LocalExecuteOptions::set_execution_profile( - ExecutionProfile* profile) { - profile_ = profile; - return *this; -} - -ExecutionProfile* LocalExecuteOptions::execution_profile() const { - return profile_; -} - -LocalExecuteOptions& LocalExecuteOptions::set_result_layout( - const Shape& shape_with_layout) { - has_result_shape_with_layout_ = true; - result_shape_with_layout_ = shape_with_layout; - return *this; -} - -const Shape* LocalExecuteOptions::result_layout() const { - return has_result_shape_with_layout_ ? 
&result_shape_with_layout_ : nullptr; -} - /* static */ StatusOr> LocalService::NewService( perftools::gputools::Platform* platform) { ServiceOptions default_options; @@ -117,9 +60,12 @@ const Shape* LocalExecuteOptions::result_layout() const { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform) + .set_number_of_replicas(options.number_of_replicas()) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, + Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); @@ -134,21 +80,6 @@ LocalService::LocalService(std::unique_ptr execute_backend, runs_in_client_process_ = true; } -tensorflow::Status LocalService::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - TF_ASSIGN_OR_RETURN(std::vector arg_allocations, - ResolveAndValidateArguments( - arguments, execute_backend_.get(), device_ordinal)); - argument_ptrs->resize(arg_allocations.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Allocation& allocation = *arg_allocations[i]; - (*argument_ptrs)[i] = allocation.device_memory(); - } - return tensorflow::Status::OK(); -} - namespace { // Returns the space required to allocate a shape. If // allocate_space_for_deep_copy the space includes all sub-buffers of @@ -159,12 +90,11 @@ int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, // TODO(b/33492279) remove once no devices represent result tuples as // contiguous buffers. if (allocate_space_for_deep_copy) { - TF_CHECK_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( shape, [&size, transfer_manager](const Shape& subshape, const ShapeIndex& /*index*/) { size += transfer_manager->GetByteSizeRequirement(subshape); - return tensorflow::Status::OK(); - })); + }); } return size; } @@ -185,302 +115,6 @@ StatusOr LocalService::AllocateBufferOnDevice( allocation_size)); } -StatusOr> LocalService::ExecuteLocally( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options) { - return ExecuteLocallyInternal(computation, arguments, options, - /*preallocated_result_buffer=*/nullptr); -} - -tensorflow::Status LocalService::ExecuteLocally( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, ShapedBuffer* result_buffer) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr null_buffer, - ExecuteLocallyInternal(computation, arguments, options, result_buffer)); - // Because the result is written into result_buffer, a null ShapedBuffer - // pointer should have been returned. 
- CHECK_EQ(nullptr, null_buffer.get()); - return tensorflow::Status::OK(); -} - -StatusOr>> -LocalService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - std::vector> module_configs; - for (const AheadOfTimeComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, - /*include_unused_parameters=*/true)); - hlo_modules.push_back(std::move(hlo_module)); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - module_configs.push_back(MakeUnique(*program_shape)); - HloModuleConfig* module_config = module_configs.back().get(); - auto* computation_layout = - module_config->mutable_entry_computation_layout(); - for (int i = 0; i < instance.argument_layouts.size(); ++i) { - const Shape& argument_layout = *instance.argument_layouts[i]; - if (ShapeUtil::IsTuple(argument_layout)) { - return Unimplemented("tuple arguments not supported yet"); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - argument_layout)); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - *instance.result_layout)); - } - - return execute_backend_->compiler()->CompileAheadOfTime( - std::move(hlo_modules), std::move(module_configs), MakeHloDumper(), - options); -} - -tensorflow::Status LocalService::ValidateExecuteOptions( - const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_layouts, - const LocalExecuteOptions& options, - const ShapedBuffer* preallocated_result_buffer) { - if (argument_layouts.size() != program_shape.parameters_size()) { - return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", - program_shape.parameters_size(), argument_layouts.size()); - } - - if (options.stream()) { - if (!options.stream()->ok()) { - return InvalidArgument("stream is uninitialized or in an error state"); - } - - // Check stream matches service platform. - const se::Platform* stream_platform = - options.stream()->parent()->platform(); - if (stream_platform != execute_backend_->platform()) { - return InvalidArgument( - "stream is for platform %s, but service targets platform %s", - stream_platform->Name().c_str(), - execute_backend_->platform()->Name().c_str()); - } - - // Cannot specify platform or device_ordinal with a stream. The stream - // determines these values. - if (options.device_ordinal() >= 0) { - return InvalidArgument( - "cannot set both device ordinal and stream options in " - "LocalExecuteOptions; the stream determines the device ordinal"); - } - if (options.platform()) { - return InvalidArgument( - "cannot set both platform and stream options in " - "LocalExecuteOptions; the stream determines the platform"); - } - } - if (options.platform() && - options.platform() != execute_backend_->platform()) { - return InvalidArgument( - "service platform (%s) does not match platform set in " - "LocalExecuteOptions (%s)", - execute_backend_->platform()->Name().c_str(), - options.platform()->Name().c_str()); - } - - // TODO(cwhipkey): validate the thread pool provided? 
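(An aside before the deleted validation code continues below: the `LocalExecuteOptions` plumbing being removed here is superseded in part by the `BackendOptions` builder used in `NewService` earlier in this patch. A minimal sketch of that construction path, assuming the includes already present in local_service.cc; the numeric values are placeholders, and the meaning of a negative thread count is an assumption.)

```c++
// Sketch only: BackendOptions, its chained setters, and
// Backend::CreateBackend come from this patch; the values are placeholders.
StatusOr<std::unique_ptr<Backend>> MakeDefaultBackend() {
  TF_ASSIGN_OR_RETURN(se::Platform* platform,
                      PlatformUtil::GetDefaultPlatform());
  BackendOptions backend_options;
  backend_options.set_platform(platform)
      .set_number_of_replicas(1)
      .set_intra_op_parallelism_threads(-1);  // assumed: < 0 selects a default
  return Backend::CreateBackend(backend_options);
}
```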
- - if (!options.allocator()) { - return InvalidArgument("an allocator must be provided to ExecuteLocally"); - } - - if (options.allocator()->platform() != execute_backend_->platform()) { - return InvalidArgument( - "allocator platform (%s) does not match service platform (%s)", - options.allocator()->platform()->Name().c_str(), - execute_backend_->platform()->Name().c_str()); - } - - if (preallocated_result_buffer != nullptr) { - if (options.result_layout()) { - return InvalidArgument( - "cannot set both result ShapedBuffer and result layout; the result " - "ShapedBuffer determines the result layout"); - } - if (!ShapeUtil::Compatible(preallocated_result_buffer->shape(), - program_shape.result())) { - return InvalidArgument( - "result ShapedBuffer of shape %s not compatible with computation " - "result shape %s", - ShapeUtil::HumanString(preallocated_result_buffer->shape()).c_str(), - ShapeUtil::HumanString(program_shape.result()).c_str()); - } - } - if (options.result_layout()) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*options.result_layout(), - program_shape.result())); - } - - // Check that all argument layouts are valid and the right shape. - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { - return InvalidArgument( - "invalid argument shape for argument %d, expected %s, got %s", i, - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - - return tensorflow::Status::OK(); -} - -StatusOr> LocalService::ExecuteLocallyInternal( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, - ShapedBuffer* preallocated_result_buffer) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Determine device ordinal the computation will run on. - int device_ordinal; - if (options.device_ordinal() >= 0) { - device_ordinal = options.device_ordinal(); - } else if (options.stream()) { - device_ordinal = options.stream()->parent()->device_ordinal(); - } else { - device_ordinal = execute_backend_->default_device_ordinal(); - } - - // Check that all arguments are on the right platform and device ordinal. - std::vector argument_layouts(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - auto argument = arguments[i]; - if (argument->platform() != execute_backend_->platform() || - argument->device_ordinal() != device_ordinal) { - return InvalidArgument( - "computation to run on device %s but argument %d is on " - "device %s:%d", - execute_backend_->device_name(device_ordinal).c_str(), i, - argument->platform()->Name().c_str(), argument->device_ordinal()); - } - argument_layouts[i] = &argument->shape(); - } - - TF_RETURN_IF_ERROR(ValidateExecuteOptions( - *program_shape, argument_layouts, options, preallocated_result_buffer)); - - // Construct computation layout from the argument layouts. 
- auto module_config = MakeUnique(*program_shape); - module_config->set_has_hybrid_result(true); - module_config->set_replica_count(execute_backend_->Replicas().size()); - std::vector argument_buffers; - auto* computation_layout = module_config->mutable_entry_computation_layout(); - for (int i = 0; i < arguments.size(); ++i) { - const ShapedBuffer* argument = arguments[i]; - if (ShapeUtil::IsTuple(argument->shape())) { - return Unimplemented("tuple arguments not supported yet"); - } - argument_buffers.push_back(argument->buffer(/*index=*/{})); - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - argument->shape())); - } - if (options.result_layout()) { - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - *options.result_layout())); - } else if (preallocated_result_buffer != nullptr) { - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - preallocated_result_buffer->shape())); - } else { - computation_layout->mutable_result_layout()->SetToDefaultLayout(); - } - - ExecutableRunOptions run_options; - run_options.set_allocator(options.allocator()); - run_options.set_inter_op_thread_pool( - execute_backend_->inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - // "acquired_stream" owns the stream used for execution if no stream is given. - std::unique_ptr acquired_stream; - if (options.stream()) { - run_options.set_stream(options.stream()); - } else { - se::StreamExecutor* stream_executor; - if (options.device_ordinal() >= 0) { - TF_ASSIGN_OR_RETURN( - stream_executor, - execute_backend_->stream_executor(options.device_ordinal())); - } else { - stream_executor = execute_backend_->default_stream_executor(); - } - TF_ASSIGN_OR_RETURN(acquired_stream, - execute_backend_->AcquireStream(stream_executor)); - run_options.set_stream(acquired_stream.get()); - } - auto stream_releaser = - ::tensorflow::gtl::MakeCleanup([this, &acquired_stream]() { - if (acquired_stream != nullptr) { - execute_backend_->ReleaseStream(std::move(acquired_stream)); - } - }); - - ExecutionProfile* profile = options.execution_profile(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - argument_buffers, execute_backend_.get(), - run_options.stream()->parent(), profile)); - - if (preallocated_result_buffer == nullptr) { - return Service::ExecuteOnStreamWrapper< - StatusOr>>( - executable.get(), &run_options, profile, - [&arguments](Executable* executable, - const ExecutableRunOptions* run_options, - HloExecutionProfile* hlo_execution_profile) { - return executable->ExecuteOnStream(run_options, arguments, - hlo_execution_profile); - }); - } else { - TF_RETURN_IF_ERROR(Service::ExecuteOnStreamWrapper( - executable.get(), &run_options, profile, - [&arguments, preallocated_result_buffer]( - Executable* executable, const ExecutableRunOptions* run_options, - HloExecutionProfile* hlo_execution_profile) { - return executable->ExecuteOnStream(run_options, arguments, - preallocated_result_buffer, - hlo_execution_profile); - })); - // To satisfy the return value type, Return a null ShapedBuffer pointer. 
- return std::unique_ptr(); - } -} - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -519,6 +153,10 @@ StatusOr> LocalService::CompileExecutable( auto module_config = MakeUnique(*program_shape); module_config->set_has_hybrid_result(has_hybrid_result); module_config->set_replica_count(execute_backend_->Replicas().size()); + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + if (flags->xla_hlo_profile) { + module_config->enable_hlo_profiling(true); + } auto* computation_layout = module_config->mutable_entry_computation_layout(); for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& shape = *argument_layouts[i]; diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 9fe0d5993b3..767a3ab697f 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -31,60 +31,6 @@ limitations under the License. namespace xla { -// Computation execution options which may be set by the client when executing -// locally (via LocalClient::ExecuteLocally). -class LocalExecuteOptions { - public: - // Specifies the allocator to use during execution. Execution will fail if no - // allocator is provided. - LocalExecuteOptions& set_allocator(DeviceMemoryAllocator* allocator); - DeviceMemoryAllocator* allocator() const; - - // If set, this is the platform to run the computation on. This must match - // the underlying platform of the service. A value of nullptr means the - // platform is not set. - // TODO(b/28616830): Support multiple platforms. - LocalExecuteOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; - - // If set, this is the device to run the computation on. Valid device_ordinal - // values are: 0 to # of devices - 1. These values are identical to the - // device ordinal values used by StreamExecutor. A value of < 0 means the - // ordinal is not set. - LocalExecuteOptions& set_device_ordinal(int device_ordinal); - int device_ordinal() const; - - // If set, this is the stream to run the computation on. The platform of the - // stream must match the service's platform. The device ordinal - // option (if set) must match the stream's device. A value of nullptr means - // the stream is not set. - LocalExecuteOptions& set_stream(perftools::gputools::Stream* stream); - perftools::gputools::Stream* stream() const; - - // If set, collect profile information during execution and fill the given - // ExecutionProfile object with the profile data. A value of nullptr means - // the profile is not set. - LocalExecuteOptions& set_execution_profile(ExecutionProfile* profile); - ExecutionProfile* execution_profile() const; - - // If set, this specifies the layout of the result of the computation. If not - // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accomodate tuple result shapes. A value of nullptr - // means the shape is not set. 
- LocalExecuteOptions& set_result_layout(const Shape& shape_with_layout); - const Shape* result_layout() const; - - private: - DeviceMemoryAllocator* allocator_ = nullptr; - perftools::gputools::Platform* platform_ = nullptr; - int device_ordinal_ = -1; - perftools::gputools::Stream* stream_ = nullptr; - ExecutionProfile* profile_ = nullptr; - - bool has_result_shape_with_layout_ = false; - Shape result_shape_with_layout_; -}; - // Service implementation that extends the XLA Service to leverage running // in the same process as the client. class LocalService : public Service { @@ -97,14 +43,6 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // For an array of arguments, validate that each is placed on the - // specified device_ordinal, and return the DeviceMemoryBase - // corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal. If allocate_space_for_deep_copy, the buffer is // large enough to hold all sub-buffers of a tuple shape, otherwise @@ -113,48 +51,6 @@ class LocalService : public Service { const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy); - // Execute the given computation with the given arguments and options with - // zero-copy data handling of arguments and result. - StatusOr> ExecuteLocally( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options); - - // Overload which writes the result into the given ShapedBuffer "result". - // Due to aliasing, not all buffers which comprise "result" may be utilized - // in the computation and thus be uninitialized. The |ShapedBuffer::buffer| - // or |ShapedBuffer::mutable_buffer| methods should be used to map an index to - // the initialized buffer. - // - // For example: - // Let 'result' be a ShapedBuffer holding a tuple with the same element, - // 'x', twice: (x, x). It is incorrect to assume that the second buffer - // which comprises 'result' is initialized. Instead, a mapping has been - // added to 'result' which can be used to recover the correct buffer. - // In this case, result->buffer({0}) should be used to extract the address of - // the first tuple element while result->buffer({1}) should be used for the - // second. - tensorflow::Status ExecuteLocally( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, ShapedBuffer* result_buffer); - - // A description of a computation to compile using CompileAheadOfTime. - struct AheadOfTimeComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |LocalClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& Options); - // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a // result of the given layout. 
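With the local-execution entry points removed, compilation is exposed through `CompileExecutable`, whose doc comment ends just above. A hypothetical call site follows; the trailing parameters are abbreviated in this hunk, so their names and order here (`result_layout`, `has_hybrid_result`) are assumptions inferred from the implementation shown earlier in this patch:

```c++
// Sketch only: `handle` is a placeholder ComputationHandle, and the return
// type std::unique_ptr<Executable> is assumed from the doc comment above.
StatusOr<std::unique_ptr<Executable>> CompileForDefaultLayout(
    LocalService* local_service, const ComputationHandle& handle) {
  // Compile for a single F32[8,7] argument; let the service pick the result
  // layout by passing a null result_layout.
  Shape arg_shape = ShapeUtil::MakeShape(F32, {8, 7});
  std::vector<const Shape*> argument_layouts = {&arg_shape};
  return local_service->CompileExecutable(handle, argument_layouts,
                                          /*result_layout=*/nullptr,
                                          /*has_hybrid_result=*/true);
}
```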
@@ -168,24 +64,6 @@ class LocalService : public Service { std::unique_ptr compute_constant_backend); LocalService(const LocalService&) = delete; void operator=(const LocalService&) = delete; - - // Internal helper for executing a computation. If result_buffer is null then - // the result is returned as a ShapedBuffer. If result_buffer is non-null then - // the result is written into result_buffer and a null ShapedBuffer pointer is - // returned. - StatusOr> ExecuteLocallyInternal( - const ComputationHandle& computation, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, - ShapedBuffer* preallocated_result_buffer); - - // Validates the given options and argument layouts and returns an appropriate - // error code. - tensorflow::Status ValidateExecuteOptions( - const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, - const LocalExecuteOptions& options, - const ShapedBuffer* preallocated_result_buffer); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 00e4b35d155..d24a592f46e 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,9 +27,9 @@ limitations under the License. namespace xla { string LogicalBuffer::ToString() const { - return tensorflow::strings::StrCat(instruction_->name(), "[", + return tensorflow::strings::StrCat(instruction_->FullyQualifiedName(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id_, ")"); + "](#", id_, " @", color_.value(), ")"); } std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { @@ -36,4 +37,26 @@ std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { return out; } +/*static*/ LogicalBufferProto::Location LogicalBuffer::ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index) { + LogicalBufferProto::Location proto; + proto.set_computation_name(instruction.parent()->name()); + proto.set_instruction_name(instruction.name()); + for (const int64 index_entry : index) { + proto.add_shape_index(index_entry); + } + return proto; +} + +LogicalBufferProto LogicalBuffer::ToProto(const SizeFunction& size_fn) const { + LogicalBufferProto proto; + proto.set_id(id_); + proto.set_size(size_fn(*this)); + LogicalBufferProto::Location proto_location = + ToLocationProto(*instruction_, index_); + proto.mutable_defined_at()->Swap(&proto_location); + proto.set_color(color_.value()); + return proto; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 35a9935f44c..566cd01ea43 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -16,22 +16,23 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ +#include #include #include #include +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { -struct HashLogicalBuffer; - // Class describing a contiguous sequence of elements (ie, C array) which form // the components of Shaped values in XLA. XLA arrays are trivially a // single LogicalBuffer. Tuple values are made up of more than one @@ -83,6 +84,8 @@ struct HashLogicalBuffer; // LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" class LogicalBuffer { public: + TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + // Id is a unique identifier for the LogicalBuffer to facilitate efficient // collections of LogicalBuffers with stable iteration order. // LogicalBuffers are typically created and accessed through @@ -90,11 +93,13 @@ class LogicalBuffer { // unique value. using Id = int64; - // Function which returns the size of a logical buffer in bytes. + // Functions which return the size and alignment of a logical buffer in bytes. using SizeFunction = std::function; + using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) - : instruction_(instruction), index_(index), id_(id) {} + LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id, + Color color) + : instruction_(instruction), index_(index), id_(id), color_(color) {} Id id() const { return id_; } @@ -105,6 +110,11 @@ class LogicalBuffer { // defined. Index used defined as in ShapeUtil::GetSubshape() const ShapeIndex& index() const { return index_; } + // Return the color of the logical buffer. Differently colored buffers can + // not be parts of the same allocation. + Color color() const { return color_; } + void set_color(Color color) { color_ = color; } + // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will // contain the layout of instruction, if any. @@ -126,29 +136,24 @@ class LogicalBuffer { bool IsArray() const { return ShapeUtil::IsArray(shape()); } string ToString() const; + LogicalBufferProto ToProto(const SizeFunction& size_fn) const; + + // Returns the LogicalBufferProto::Location that serializes the given + // instruction and index. + static LogicalBufferProto::Location ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index); private: - friend struct HashLogicalBuffer; HloInstruction* instruction_; ShapeIndex index_; Id id_; + Color color_; // Similar to HLO constructs (HloInstruction, etc), pointers are used for // comparison to equality, so disable all copying. 
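(Before the header hunk concludes below: because `Color` is defined with `TF_LIB_GTL_DEFINE_INT_TYPE`, it behaves as a strong integer typedef. Colors compare against each other, but there is no implicit conversion to or from a raw `int64`. A small sketch of that behavior, with illustrative values:)

```c++
// Sketch only: demonstrates the strong-typedef semantics of Color.
void ColorSketch(LogicalBuffer* buffer) {
  LogicalBuffer::Color red(0);
  LogicalBuffer::Color blue(1);
  CHECK(red != blue);          // colors are comparable among themselves
  CHECK_EQ(blue.value(), 1);   // the raw int64 must be requested explicitly
  // int64 bad = blue;         // would not compile: no implicit conversion
  buffer->set_color(blue);     // differently colored buffers never end up
                               // in the same allocation
}
```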
TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer); }; -struct HashLogicalBuffer { - size_t operator()(const LogicalBuffer& b) const { - std::hash hasher; - size_t h = hasher(b.instruction_); - for (int i = 0; i < b.index_.size(); i++) { - h += static_cast(b.index_[i] << i); - } - return h; - } -}; - std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer); } // namespace xla diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h new file mode 100644 index 00000000000..8e710ebb6dc --- /dev/null +++ b/tensorflow/compiler/xla/service/pool.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_POOL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { + +// Pool of values, which are created as needed and destroyed when the `Pool` is +// destroyed +template +class Pool { + public: + struct Deleter { + void operator()(T* ptr) { pool->Deallocate(ptr); } + + Pool* pool; + }; + + // A pointer to a taken element of a `Pool` which returns it to the pool on + // destruction + using SmartPtr = std::unique_ptr; + + // Constructs a `Pool` with given factory function, which need not be + // thread-safe. + explicit Pool(std::function()> factory) + : factory_(factory) {} + + explicit Pool() : Pool([]() { return MakeUnique(); }) {} + + // Returns a pointer to a value in the pool, creating a new value if none is + // free. The returned smart pointer returns the element to the pool on + // destruction. + // + // This method is thread-safe. + SmartPtr Allocate() { + tensorflow::mutex_lock lock(mu_); + T* ptr; + if (!xs_.empty()) { + ptr = std::move(xs_.back()).release(); + xs_.pop_back(); + } else { + ptr = factory_().release(); + } + Deleter del = {this}; + return std::unique_ptr(ptr, del); + } + + private: + // Puts a pointer to a value back into the pool, leaving it free for future + // use. + // + // This method is thread-safe. + void Deallocate(T* ptr) { + tensorflow::mutex_lock lock(mu_); + xs_.push_back(std::unique_ptr(ptr)); + } + + const std::function()> factory_ GUARDED_BY(mu_); + std::vector> xs_ GUARDED_BY(mu_); + tensorflow::mutex mu_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_POOL_H_ diff --git a/tensorflow/compiler/xla/service/pool_test.cc b/tensorflow/compiler/xla/service/pool_test.cc new file mode 100644 index 00000000000..8c4fe258e38 --- /dev/null +++ b/tensorflow/compiler/xla/service/pool_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/pool.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+using PoolTest = ::testing::Test;
+
+TEST_F(PoolTest, Test) {
+  Pool<int> pool;
+
+  {
+    auto ptr = pool.Allocate();
+    EXPECT_NE(nullptr, ptr.get());
+    *ptr = 5;
+  }
+
+  auto ptr = pool.Allocate();
+  EXPECT_NE(nullptr, ptr.get());
+  EXPECT_EQ(5, *ptr);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index f8023f1c375..2d35ba5e548 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -13,16 +13,98 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+// Implementation note:
+//
+// The general idea behind this pass is that we're converting from this:
+//   %param.A = OldShape
+//   %param.B = OldShape
+//   %reshape.A = NewShape reshape(%param.A)
+//   %reshape.B = NewShape reshape(%param.B)
+//   %instruction = NewShape instruction(%reshape.A, %reshape.B)
+// To this:
+//   %param.A = OldShape
+//   %param.B = OldShape
+//   %instruction = OldShape instruction(%param.A, %param.B)
+//   %reshape = NewShape reshape(%instruction)
+//
+// Where the instruction must be elementwise, and both reshapes and transposes
+// are moved.
+//
+// Most elementwise instructions support implicit broadcast of scalar operands,
+// but select is a special-case. The signature is Select(Pred, A, B), and the
+// only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
+// transposes to a scalar should be cheap, we simply never move them.
+
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 
 #include 
 
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
 
 namespace {
 
+// Checks if an instruction can change its shape simply by adjusting metadata.
+// This is the case if it is:
+//
+// - an instruction that has no producers, such as a Constant or Rng
+//   instruction, or is a scalar.
+//
+// Or
+//
+// - a reshape/transpose instruction with an operand that can trivially change
+//   its shape.
+bool InstructionCanTriviallyChangeShape(const HloInstruction* instruction) {
+  // Reshape/Transposes are only trivial if their operand is trivial.
+  if (instruction->opcode() == HloOpcode::kReshape ||
+      instruction->opcode() == HloOpcode::kTranspose) {
+    CHECK_EQ(instruction->operand_count(), 1);
+    return InstructionCanTriviallyChangeShape(instruction->operand(0));
+  }
+
+  // Scalars can operate with any shape.
+  if (ShapeUtil::IsScalar(instruction->shape())) {
+    return true;
+  }
+
+  // A constant can trivially reshape the literal it holds.
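(Returning briefly to the `Pool` class added above before the reshape-mover code continues below: its test exercises the default factory, and the sketch here shows the custom-factory form together with the RAII return-to-pool behavior. The string payload and the helper function are illustrative.)

```c++
// Sketch only: Allocate() hands out a free object or builds one via the
// factory; the SmartPtr's deleter returns it to the pool rather than
// destroying it.
void PoolSketch() {
  Pool<string> pool([] { return MakeUnique<string>("fresh"); });
  {
    Pool<string>::SmartPtr s = pool.Allocate();  // invokes the factory
    *s = "warm";
  }  // `s` goes back into the pool here, not destroyed
  auto reused = pool.Allocate();  // reuses the pooled object: *reused == "warm"
}
```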
+ if (instruction->opcode() == HloOpcode::kConstant) { + return true; + } + + // An Rng instruction can be any shape as long as it has one user. Two copies + // of the same Rng would be problematic if an Rng of a different shape would + // produce random numbers in a different order. + if (instruction->opcode() == HloOpcode::kRng && + instruction->user_count() == 1) { + return true; + } + return false; +} + +// Finds the first non-scalar operand of an instruction that is a non-trivial +// reshape or transpose. Returns the operand if it is found or nullptr if not +// found. +HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( + const HloInstruction* hlo) { + for (HloInstruction* operand : hlo->operands()) { + if (!ShapeUtil::IsScalar(operand->shape()) && + ((operand->opcode() == HloOpcode::kReshape || + operand->opcode() == HloOpcode::kTranspose) && + !InstructionCanTriviallyChangeShape(operand->operand(0)))) { + VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " + << hlo->ToStringNoMetadata() << ":\n\t" + << operand->ToStringNoMetadata(); + return operand; + } + } + return nullptr; +} + // Returns whether `a` and `b` are equivalent for the purposes of this pass. bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { if (a->opcode() != b->opcode() || @@ -40,82 +122,204 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } +// Returns true if an elementwise operation has all operands that can easily +// change shape. Operands can easily change shape if they are all +// reshapes/transposes to and from the same shape. Additionally, operands like +// constant, rng, and any scalar change shape with only an adjustment of +// metadata. bool IsElementwiseOfEquivalentReshapesOrTransposes( const HloInstruction* instruction) { const std::vector& operands = instruction->operands(); - return instruction->IsElementwise() && instruction->operand_count() > 0 && - std::all_of(operands.begin(), operands.end(), - [](const HloInstruction* instruction) { - // We require operand have no other users as otherwise - // this is not a clear win. - return 1 == instruction->users().size(); - }) && - // Check whether each operand beyond the first is equivalent to the - // first. - std::all_of(operands.begin(), operands.end(), - [&operands](const HloInstruction* operand) { - return AreEquivalentReshapes(operands[0], operand); - }); + HloInstruction* first_reshape_operand = + FirstNonScalarAndNonTrivialReshapeOperand(instruction); + // If there are no non-trivial reshapes or transposes, then there is nothing + // to sink below the elementwise operation. + if (!first_reshape_operand) { + return false; + } + VLOG(3) << "** Checking whether instruction is an elementwise operation of " + "equivalent reshapes/transposes: " + << instruction->ToStringNoMetadata(); + bool result = (instruction->user_count() > 0 || + instruction == instruction->parent()->root_instruction()) && + instruction->IsElementwise() && !operands.empty(); + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, it may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. 
Are one of kConstant, kRng, and scalars that can change shape
+  //      trivially.
+  if (result) {
+    for (auto& operand : operands) {
+      if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
+        VLOG(5) << "Operand shape differs from output shape; may be "
+                   "implicitly broadcast, so preventing "
+                   "movement\n\toperand: "
+                << operand->ToStringNoMetadata()
+                << "\n\tinstruction: " << instruction->ToStringNoMetadata();
+        result = false;
+        break;
+      }
+
+      if (AreEquivalentReshapes(first_reshape_operand, operand)) {
+        VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
+                << first_reshape_operand->ToStringNoMetadata()
+                << "\n\toperand: " << operand->ToStringNoMetadata();
+        continue;
+      }
+
+      if (InstructionCanTriviallyChangeShape(operand)) {
+        VLOG(5) << "Operand can trivially change shape: "
+                << operand->ToStringNoMetadata();
+        continue;
+      }
+
+      // TODO(someone): Look into supporting general ops for the operands as
+      // well.
+      VLOG(5) << "Operand is neither equivalent to the first Reshape operand "
+                 "nor can trivially change shape: "
+              << operand->ToStringNoMetadata();
+      result = false;
+      break;
+    }
+  }
+
+  VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
+          << instruction->ToStringNoMetadata() << ": " << result;
+  return result;
 }
 
 // Try to sink any reshape or transpose operands of `instruction` across it. We
-// do so if `instruction` is elementwise and all operands are equivalent
-// reshapes or transposes.
-bool TrySinkReshapeOrTranspose(HloComputation* computation,
-                               HloInstruction* instruction) {
-  if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
-    std::vector operands = instruction->operands();
-    auto old_reshape = operands[0];
-    for (size_t i = 0; i < operands.size(); ++i) {
-      operands[i] = operands[i]->mutable_operand(0);
+// do so if `instruction` is elementwise and all operands are either equivalent
+// reshapes/transposes or are trivially reshapable. Note that no move is
+// performed if there are no nontrivial reshapes/transposes.
+StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
+                                         HloInstruction* instruction) {
+  if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+    return false;
+  }
+
+  HloInstruction* old_reshape =
+      FirstNonScalarAndNonTrivialReshapeOperand(instruction);
+  TF_RET_CHECK(old_reshape != nullptr);
+  Shape new_elementwise_shape = old_reshape->operand(0)->shape();
+
+  VLOG(3) << "** Trying to sink reshape or transpose: "
+          << instruction->ToStringNoMetadata()
+          << "\n\told reshape: " << old_reshape->ToStringNoMetadata()
+          << "\n\tnew elementwise shape: "
+          << ShapeUtil::HumanString(new_elementwise_shape);
+
+  std::vector<HloInstruction*> operands = instruction->operands();
+  for (size_t i = 0; i < operands.size(); ++i) {
+    // All scalar operands remain as-is, even if they're reshape or transpose,
+    // to simplify handling wrt special scalar broadcast rules for ops like
+    // Select. Scalar reshapes should be cheap anyway.
+    if (ShapeUtil::IsScalar(operands[i]->shape())) {
+      continue;
     }
-    auto new_elementwise =
-        computation->AddInstruction(instruction->CloneWithNewOperands(
-            // `instruction` may change the element type, e.g., from
-            // operands[0] -> reshape -> convert (`instruction`)
-            // to
-            // operands[0] -> convert' -> reshape'
-            //
-            // In this case, convert' should have the same element type as
-            // `convert` and the same dimensions as operands[0].
-            ShapeUtil::MakeShape(
-                instruction->shape().element_type(),
-                AsInt64Slice(operands[0]->shape().dimensions())),
-            operands));
-    std::unique_ptr new_reshape;
-    switch (old_reshape->opcode()) {
-      case HloOpcode::kReshape:
-        new_reshape = HloInstruction::CreateReshape(instruction->shape(),
-                                                    new_elementwise);
+    PrimitiveType element_type = operands[i]->shape().element_type();
+    switch (operands[i]->opcode()) {
+      case HloOpcode::kConstant: {
+        if (old_reshape->opcode() == HloOpcode::kReshape) {
+          VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
+                  << operands[i]->ToStringNoMetadata();
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateReshape(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i]));
+        } else {
+          TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
+          std::vector<int64> inverse_permutation =
+              InversePermutation(old_reshape->dimensions());
+          operands[i] = instruction->parent()->AddInstruction(
+              HloInstruction::CreateTranspose(
+                  ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                               element_type),
+                  operands[i], inverse_permutation));
+        }
         break;
+      }
+      case HloOpcode::kRng: {
+        CHECK_EQ(operands[i]->user_count(), 1);
+        operands[i] = instruction->parent()->AddInstruction(
+            operands[i]->CloneWithNewOperands(
+                ShapeUtil::ChangeElementType(new_elementwise_shape,
+                                             element_type),
+                operands[i]->operands()));
+        break;
+      }
+      case HloOpcode::kReshape:
       case HloOpcode::kTranspose:
-        new_reshape = HloInstruction::CreateTranspose(
-            instruction->shape(), new_elementwise, old_reshape->dimensions());
+        operands[i] = operands[i]->mutable_operand(0);
         break;
       default:
-        LOG(FATAL) << "Bad opcode";
+        LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
+                      "transposes.";
     }
-    TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
-                                                       std::move(new_reshape)));
-    return true;
   }
-  return false;
+  if (HloOpcode::kFusion == instruction->opcode()) {
+    // Here we already know `instruction` is elementwise, and no operand is
+    // implicitly broadcast (if one were, the operands would not be equivalent
+    // reshapes), so all the fused instructions have the same dimensions.
+    for (const auto& fused_instruction : instruction->fused_instructions()) {
+      Shape* shape = fused_instruction->mutable_shape();
+      *shape->mutable_dimensions() = new_elementwise_shape.dimensions();
+      *shape->mutable_layout() = new_elementwise_shape.layout();
+    }
+  }
+  HloInstruction* new_elementwise =
+      computation->AddInstruction(instruction->CloneWithNewOperands(
+          // `instruction` may change the element type, e.g., from
+          // operands[0] -> reshape -> convert (`instruction`)
+          // to
+          // operands[0] -> convert' -> reshape'
+          //
+          // In this case, convert' should have the same element type as
+          // `convert` and the same dimensions as operands[0].
+ ShapeUtil::ChangeElementType(new_elementwise_shape, + instruction->shape().element_type()), + operands)); + + std::unique_ptr new_reshape; + switch (old_reshape->opcode()) { + case HloOpcode::kReshape: + VLOG(3) << "Creating new reshape for new elementwise op: " + << new_elementwise->ToStringNoMetadata(); + new_reshape = + HloInstruction::CreateReshape(instruction->shape(), new_elementwise); + break; + case HloOpcode::kTranspose: + new_reshape = HloInstruction::CreateTranspose( + instruction->shape(), new_elementwise, old_reshape->dimensions()); + break; + default: + LOG(FATAL) << "Bad opcode"; + } + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, std::move(new_reshape))); + return true; } } // namespace StatusOr ReshapeMover::Run(HloModule* module) { - return std::any_of( - module->computations().begin(), module->computations().end(), - [](const std::unique_ptr& computation) { - std::list postorder = - computation->MakeInstructionPostOrder(); - return std::any_of(postorder.begin(), postorder.end(), - [&computation](HloInstruction* instruction) { - return TrySinkReshapeOrTranspose(computation.get(), - instruction); - }); - }); + bool changed = false; + for (const auto& comp : module->computations()) { + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN(bool did_change, + TrySinkReshapeOrTranspose(comp.get(), instruction)); + changed |= did_change; + } + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 850295c7261..9becdb2bed4 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -20,38 +20,523 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { using ReshapeMoverTest = HloTestBase; -TEST_F(ReshapeMoverTest, ReshapesWithNonSameInputShapesNotMoved) { - auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); +TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param0")); - auto reshape2 = + 1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1")); + auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); - auto reshape3 = + auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); - auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( - root_shape, HloOpcode::kAdd, reshape2, reshape3)); + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = MakeUnique(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(add4, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); - EXPECT_EQ(add4, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); +} + +// For a graph that looks like: +// +// +- reshape0 - rng0 +// | +// +- const1 +// | +// add +// +// where rng0 has a different shape than reshape0. +// +// Verifies that the reshape is not moved, since rng0 is trivially reshapable +// and therefore there is no nontrivial reshapes to move. 
+TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), + RandomDistribution::RNG_UNIFORM, {})); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); + + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape))); + + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, const1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(rng0), const1)); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(rng0), const1)); +} + +TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param1")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter()))); +} + +TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Add(param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +// For a graph that looks like: +// +// +- reshape2 - param2 +// | +// +- reshape1 - param1 +// | +// +- constant0 +// | +// select +// +// Verifies that the reshape1 and reshape2 sink past select: +// +// +- param2 +// | +// +- param1 +// | +// +- reshape3(constant0) +// | +// select +// 
| +// reshape4 +TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, true, false}, {false, false, true}}))); + + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + + auto param2 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param2")); + auto reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2)); + + builder.AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {2, 3}), HloOpcode::kSelect, const0, reshape1, + reshape2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Select(const0, reshape1, reshape2)); + + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Select(op::Reshape(const0), param1, param2))); + + EXPECT_EQ(const0->shape().DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +// For a graph that looks like: +// +// +- reshape0 - param0 +// | +// +- param1 +// | +// add +// +// Verifies that the reshape0 does not sink below add, because param1 is not +// trivially reshapable nor is a Reshape/Transpose. +TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, param1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), param1)); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), param1)); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +// For a graph that looks like: +// +// +- pred +// | +// +- reshape0 - const0 +// | +// +- reshape1 - const1 +// | +// select +// +// Verifies that we don't unnecessarily sink reshapes, which are in fact +// trivial reshapes. 
+TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); + + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); + + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(PRED, {1, 3, 1, 2}), "pred")); + + builder.AddInstruction(HloInstruction::CreateTernary( + root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Select(pred, op::Reshape(const0), op::Reshape(const1))); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Select(pred, op::Reshape(const0), op::Reshape(const1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +// For a graph that looks like: +// +// +- reshape0 - param0 +// | +// +- const1 +// | +// add +// +// where there is only 1 non-trivial reshape (reshape0), we sink the reshape +// here for canonicalization benefit: +// +// +- param0 +// | +// +- reshape1 - const1 +// | +// add +// | +// reshape2 +// +// (note that reshape1 here is trivial). +TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, const1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), const1)); + + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Add(param0, op::Reshape(const1)))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto module = CreateNewModule(); + auto computation = 
module->AddEntryComputation(builder.Build()); + auto fusion = computation->AddInstruction(HloInstruction::CreateFusion( + add->shape(), HloInstruction::FusionKind::kLoop, add)); + TF_CHECK_OK(computation->ReplaceInstruction(add, fusion)); + + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Reshape(param0), op::Reshape(param1))); + + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Fusion(param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); + auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {1, 8, 1, 7}), "pred")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); + auto reshape_pred = + builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred)); + builder.AddInstruction(HloInstruction::CreateTernary( + root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT( + computation->root_instruction(), + op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); + + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Select(pred, param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + +TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {}); + auto pred_shape = ShapeUtil::MakeShape(PRED, {}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "param1")); + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {1, 1, 1}), "pred")); + auto reshape_pred = + builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred)); + auto select = builder.AddInstruction(HloInstruction::CreateTernary( + root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Select(op::Reshape(pred), param0, param1)); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Select(op::Reshape(pred), param0, param1)); + EXPECT_EQ(select, computation->root_instruction()); +} + +// Tree looks like: +// +// param0 [1,128,1] +// | +// reshape [128,1] constant [128,1024] +// \ / +// multiply w/implicit broadcast [128,1024] +// +// The reshape mover would like to sink the 
+//
+// Previously we would attempt to insert a reshape of the constant to [1,128,1]
+// (which is unsound, because it has a different number of elements) as
+// preparation for sinking the reshape.
+//
+// To eliminate the unsoundness, we outlaw reshape sinking when one of the
+// operands is implicitly broadcast in the elementwise consumer.
+//
+// TODO(b/37799338) However, it would be possible in this case to do a more
+// in-depth analysis to get reshape movement to occur:
+//
+// 1. Note that the broadcast dimension (logical dimension 1) in the operands
+//    would map back to logical dimension 2 in the param0 node.
+// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
+//    dimension).
+// 3. Reshape to [128,1024] at the root.
+//
+// But this is not currently done.
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
+  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeShape(F32, {128, 1}), param0));
+  Array2D<float> a(128, 1024);
+  auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+  auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
+      constant->shape(), HloOpcode::kMultiply, constant, reshape));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+
+  EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(computation->root_instruction(),
+              op::Multiply(op::Constant(), op::Reshape(param0)));
+  EXPECT_EQ(multiply, computation->root_instruction());
+}
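The unsoundness described in the comment above reduces to an element-count invariant: a reshape can only be re-applied to a sibling operand when both operands cover the same number of elements, and an implicitly broadcast operand does not. A standalone sketch of that check, using bare dimension vectors rather than XLA's Shape type (`CanSinkReshape` is an illustrative stand-in, not the ReshapeMover's actual predicate):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Number of elements covered by a dense shape.
int64_t ElementCount(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// A reshape of `from` can only be re-applied to a sibling operand of shape
// `to` when both describe the same number of elements; an implicitly
// broadcast operand (e.g. [128,1] against [128,1024]) fails this test.
bool CanSinkReshape(const std::vector<int64_t>& from,
                    const std::vector<int64_t>& to) {
  return ElementCount(from) == ElementCount(to);
}

int main() {
  // The case from the test: constant [128,1024] vs. reshape operand [1,128,1].
  std::cout << CanSinkReshape({128, 1024}, {1, 128, 1}) << "\n";  // 0: unsound
  std::cout << CanSinkReshape({1, 8, 1, 7}, {8, 7}) << "\n";      // 1: sound
  return 0;
}
```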
+
+// Tree looks like this:
+//
+// add1
+// |
+// +- reshape2 - param2
+// |
+// +- reshape3 - add0
+//               |
+//               +- reshape0 - param0
+//               |
+//               +- reshape1 - param1
+//
+// We expect reshape{0,1} AND reshape{2,3} to be lifted.
+TEST_F(ReshapeMoverTest, MultiplePasses) {
+  auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
+  auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
+  auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape1, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape1, "param1"));
+  auto param2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, shape2, "param2"));
+  auto reshape0 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
+  auto reshape1 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
+  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape2, HloOpcode::kAdd, reshape0, reshape1));
+  auto reshape2 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
+  auto reshape3 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
+  builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd,
+                                                      reshape2, reshape3));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(
+      computation->root_instruction(),
+      op::Add(op::Reshape(param2),
+              op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1)))));
+
+  EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
+
+  EXPECT_THAT(
+      computation->root_instruction(),
+      op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1)))));
+}
 
 }  // namespace
 }  // namespace xla
+
+int main(int argc, char** argv) {
+  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}
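The MultiplePasses test above relies on a single `Run()` call lifting both the inner and the outer layer of reshapes, i.e. the pass keeps rewriting until nothing changes. A minimal sketch of that run-until-fixed-point contract, with a stand-in pass type instead of XLA's HLO pass interface (`CountdownPass` is hypothetical, used only to make the loop observable):

```cpp
#include <iostream>

// Stand-in for a rewriting pass: Run() mutates some state and reports whether
// anything changed, mirroring the bool returned by ReshapeMover().Run(module).
struct CountdownPass {
  int layers_left;
  bool Run() {
    if (layers_left == 0) return false;
    --layers_left;  // "lift" one layer of reshapes
    return true;
  }
};

int main() {
  CountdownPass pass{2};  // two nested layers, as in the MultiplePasses test
  int iterations = 0;
  // Re-run the pass until it reaches a fixed point.
  while (pass.Run()) ++iterations;
  std::cout << "converged after " << iterations << " iterations\n";  // 2
  return 0;
}
```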
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" namespace se = ::perftools::gputools; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrCat; + namespace xla { namespace { @@ -76,8 +77,10 @@ tensorflow::Status RecordArguments( SessionModule* module) { module->clear_arguments(); for (const Allocation* allocation : arg_allocations) { - TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(), - module->add_arguments())); + Literal argument; + TF_RETURN_IF_ERROR( + LiteralFromAllocation(allocation, allocation->shape(), &argument)); + *module->add_arguments() = argument.ToProto(); } return tensorflow::Status::OK(); } @@ -86,8 +89,11 @@ tensorflow::Status RecordArguments( tensorflow::Status RecordResult(const Allocation* result_allocation, SessionModule* module) { module->clear_result(); - return LiteralFromAllocation(result_allocation, result_allocation->shape(), - module->mutable_result()); + Literal result; + TF_RETURN_IF_ERROR(LiteralFromAllocation( + result_allocation, result_allocation->shape(), &result)); + *module->mutable_result() = result.ToProto(); + return tensorflow::Status::OK(); } } // namespace @@ -109,6 +115,16 @@ ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } +ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int ServiceOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ StatusOr> Service::NewService( perftools::gputools::Platform* platform) { ServiceOptions default_options; @@ -123,9 +139,10 @@ int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - execute_backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(options.number_of_replicas()); + TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new Service( @@ -139,37 +156,18 @@ Service::CreateComputeConstantBackend() { PlatformUtil::GetSupportedPlatforms()); for (auto* platform : platforms) { if (platform->id() == se::host::kHostPlatformId) { - return Backend::CreateBackend(platform, /*replica_count=*/1); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(1); + return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); } -/* static */ void Service::DumpExecutedHlo(const HloModule& module, - const string& label, - const HloExecutionProfile* profile) { - VLOG(2) << "module name = " << module.name(); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - if (!flags->xla_generate_hlo_graph.empty() && - RE2::PartialMatch(module.name(), flags->xla_generate_hlo_graph)) { - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - 
-/* static */ void Service::DumpExecutedHlo(const HloModule& module,
-                                           const string& label,
-                                           const HloExecutionProfile* profile) {
-  VLOG(2) << "module name = " << module.name();
-  legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
-  if (!flags->xla_generate_hlo_graph.empty() &&
-      RE2::PartialMatch(module.name(), flags->xla_generate_hlo_graph)) {
-    hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
-                                flags->xla_hlo_graph_addresses,
-                                flags->xla_hlo_graph_layout, profile);
-  }
-  if (!flags->xla_log_hlo_text.empty() &&
-      RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) {
-    LOG(INFO) << "HLO for module " << module.name();
-    LOG(INFO) << "Label: " << label;
-    XLA_LOG_LINES(2, module.ToString());
-  }
-  if (!flags->xla_dump_hlo_text_to.empty()) {
-    hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to);
-  }
-}
-
 /* static */ Compiler::HloDumper Service::MakeHloDumper() {
   return [](const HloModule& module, const string& label) {
-    return DumpExecutedHlo(module, label, /*profile=*/nullptr);
+    return Executable::DumpExecutedHlo(module, label, /*profile=*/nullptr);
   };
 }
 
@@ -177,20 +175,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
                  std::unique_ptr<Backend> compute_constant_backend)
     : execute_backend_(std::move(execute_backend)),
       compute_constant_backend_(std::move(compute_constant_backend)) {
-  LOG(INFO) << "XLA service executing computations on platform "
-            << execute_backend_->platform()->Name() << ". Devices:";
-  for (int i = 0; i < execute_backend_->device_count(); ++i) {
-    if (execute_backend_->device_ordinal_supported(i)) {
-      se::StreamExecutor* executor =
-          execute_backend_->stream_executor(i).ValueOrDie();
-      const auto& description = executor->GetDeviceDescription();
-      LOG(INFO) << tensorflow::strings::Printf(
-          "  StreamExecutor device (%d): %s, %s", i, description.name().c_str(),
-          description.platform_version().c_str());
-    } else {
-      LOG(INFO) << tensorflow::strings::Printf(
-          "  StreamExecutor device (%d) not supported", i);
+  if (execute_backend_) {
+    LOG(INFO) << Printf(
+        "XLA service %p executing computations on platform %s. Devices:", this,
+        execute_backend_->platform()->Name().c_str());
+    for (int i = 0; i < execute_backend_->device_count(); ++i) {
+      if (execute_backend_->device_ordinal_supported(i)) {
+        se::StreamExecutor* executor =
+            execute_backend_->stream_executor(i).ValueOrDie();
+        const auto& description = executor->GetDeviceDescription();
+        LOG(INFO) << Printf("  StreamExecutor device (%d): %s, %s", i,
+                            description.name().c_str(),
+                            description.platform_version().c_str());
+      } else {
+        LOG(INFO) << Printf("  StreamExecutor device (%d) not supported", i);
+      }
     }
+  } else {
+    VLOG(1) << "XLA compile-only service constructed";
   }
 }
 
@@ -202,6 +204,8 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg,
 
   *result->mutable_computation() =
       computation_tracker_.NewComputation(arg->name());
+  VLOG(1) << Printf("Created new computation %s on service %p",
+                    result->computation().ShortDebugString().c_str(), this);
   return tensorflow::Status::OK();
 }
 
@@ -251,13 +255,12 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
     tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
     const Backend* backend, int device_ordinal) {
   std::vector<const Allocation*> allocations;
-  for (int i = 0; i < arguments.size(); ++i) {
+  for (size_t i = 0; i < arguments.size(); ++i) {
     auto allocation_status = allocation_tracker_.Resolve(*arguments[i]);
     if (!allocation_status.ok()) {
       return Status(allocation_status.status().code(),
-                    tensorflow::strings::StrCat(
-                        allocation_status.status().error_message(), ", ",
-                        "failed to resolve allocation for parameter ", i));
+                    StrCat(allocation_status.status().error_message(), ", ",
+                           "failed to resolve allocation for parameter ", i));
     }
     const Allocation* allocation = allocation_status.ValueOrDie();
 
@@ -265,7 +268,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
     if (allocation->backend() != backend ||
         allocation->device_ordinal() != device_ordinal) {
       return InvalidArgument(
-          "argument %d is on device %s but computation will be executed "
+          "argument %lu is on device %s but computation will be executed "
           "on device %s",
           i,
           allocation->backend()
@@ -282,7 +285,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
 
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
     tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
+    const ExecutionOptions& execution_options, Backend* backend) {
   auto module_config = MakeUnique<HloModuleConfig>(program_shape);
   auto* computation_layout = module_config->mutable_entry_computation_layout();
 
@@ -291,13 +294,13 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
         program_shape.parameters_size(), arguments.size());
   }
-  for (int i = 0; i < arguments.size(); ++i) {
+  for (size_t i = 0; i < arguments.size(); ++i) {
     // Verify that shape of arguments matches the shape of the arguments in the
     // ProgramShape.
     if (!ShapeUtil::Compatible(arguments[i]->shape(),
                                program_shape.parameters(i))) {
       return InvalidArgument(
-          "computation expects parameter %d to have shape %s, given shape %s",
+          "computation expects parameter %lu to have shape %s, given shape %s",
           i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
           ShapeUtil::HumanString(arguments[i]->shape()).c_str());
     }
@@ -322,9 +325,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  module_config->set_replica_count(execute_backend_->Replicas().size());
-  module_config->set_fast_math_disabled(execution_options.disable_fast_math());
+  module_config->set_replica_count(backend->Replicas().size());
   module_config->set_seed(execution_options.seed());
+  module_config->set_debug_options(execution_options.debug_options());
 
   return std::move(module_config);
 }
@@ -334,6 +337,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
     Backend* backend,
     std::vector<perftools::gputools::StreamExecutor*> executors) {
+  VLOG(1) << Printf("BuildExecutable on service %p", this);
+
   // Dump computation proto state if flag is set.
   std::vector<std::unique_ptr<SessionModule>> session_modules;
   legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
@@ -345,11 +350,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
           computation_tracker_.SnapshotComputation(
               versioned_handles[i].handle));
       if (!directory_path.empty()) {
-        string filename =
-            tensorflow::strings::Printf("computation_%lld__%s__version_%lld",
-                                        versioned_handles[i].handle.handle(),
-                                        session_module->entry().name().c_str(),
-                                        versioned_handles[i].version);
+        string filename = Printf("computation_%lld__%s__version_%lld",
+                                 versioned_handles[i].handle.handle(),
+                                 session_module->entry().name().c_str(),
+                                 versioned_handles[i].version);
         TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
                                                        *session_module));
         session_modules.push_back(std::move(session_module));
@@ -357,29 +361,31 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
     }
   }
 
-  VLOG(1) << "building executables from:";
+  VLOG(1) << "Computation handles:";
   for (const VersionedComputationHandle& versioned_handle :
        versioned_handles) {
-    VLOG(1) << versioned_handle.handle.handle() << "@v"
-            << versioned_handle.version;
+    VLOG(1) << versioned_handle;
   }
 
+  CHECK_EQ(versioned_handles.size(), module_configs.size());
   std::vector<std::unique_ptr<HloModule>> modules;
-  for (const VersionedComputationHandle& versioned_handle :
-       versioned_handles) {
+  for (int64 i = 0; i < versioned_handles.size(); ++i) {
+    const VersionedComputationHandle& versioned_handle = versioned_handles[i];
+    const HloModuleConfig& config = *module_configs[i];
     TF_ASSIGN_OR_RETURN(auto module, computation_tracker_.BuildHloModule(
-                                         versioned_handle,
-                                         /*include_unused_parameters=*/true));
+                                         versioned_handle, config,
+                                         /*include_unreachable_instructions=*/true));
     modules.push_back(std::move(module));
   }
 
   Compiler::HloDumper hlo_dumper = MakeHloDumper();
-  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
-                      backend->compiler()->Compile(
-                          std::move(modules), std::move(module_configs),
-                          hlo_dumper, std::move(executors)));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::unique_ptr<Executable>> executables,
+      backend->compiler()->Compile(std::move(modules), hlo_dumper,
+                                   std::move(executors)));
 
   if (!other_directory_path.empty()) {
-    for (int64 i = 0; i < versioned_handles.size(); ++i) {
+    for (size_t i = 0; i < versioned_handles.size(); ++i) {
       executables[i]->set_session_module(std::move(session_modules[i]));
     }
   }
@@ -394,6 +400,9 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const tensorflow::gtl::ArraySlice<const Allocation*> arguments,
     Backend* backend, se::StreamExecutor* executor) {
+  VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this,
+                    versioned_handle.ToString().c_str());
+
   // Dump computation proto state if flag is set.
   std::unique_ptr<SessionModule> session_module;
   legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
@@ -405,24 +414,20 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
         session_module,
         computation_tracker_.SnapshotComputation(versioned_handle.handle));
     if (!directory_path.empty()) {
-      string filename = tensorflow::strings::Printf(
-          "computation_%lld__%s__version_%lld",
-          versioned_handle.handle.handle(),
-          session_module->entry().name().c_str(), versioned_handle.version);
+      string filename = Printf("computation_%lld__%s__version_%lld",
+                               versioned_handle.handle.handle(),
+                               session_module->entry().name().c_str(),
+                               versioned_handle.version);
       TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
                                                      *session_module));
     }
   }
 
-  VLOG(1) << tensorflow::strings::Printf("building executable %lld@v%lld",
-                                         versioned_handle.handle.handle(),
-                                         versioned_handle.version);
-
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModule> module,
-      computation_tracker_.BuildHloModule(
-          versioned_handle,
-          /*include_unused_parameters=*/!executable_for_compute_constant));
+      computation_tracker_.BuildHloModule(versioned_handle, *module_config,
+                                          /*include_unreachable_instructions=*/
+                                          !executable_for_compute_constant));
 
   Compiler::HloDumper hlo_dumper = MakeHloDumper();
   if (executable_for_compute_constant &&
@@ -432,8 +437,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
 
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable,
-      backend->compiler()->Compile(std::move(module), std::move(module_config),
-                                   hlo_dumper, executor));
+      backend->compiler()->Compile(std::move(module), hlo_dumper, executor));
 
   if (!other_directory_path.empty()) {
     executable->set_session_module(std::move(session_module));
@@ -471,7 +475,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
       std::unique_ptr<Executable> executable_unique_ptr,
       BuildExecutable(versioned_handle, std::move(module_config),
                       /*executable_for_compute_constant=*/false, arguments,
-                      execute_backend_.get(), executor));
+                      backend, executor));
 
   if (profile != nullptr) {
     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -498,36 +502,30 @@ Service::ExecuteParallelAndRegisterResult(
   TF_RET_CHECK(backend->Replicas().size() == 1);
 
   // Set up streams.
-  std::vector<std::unique_ptr<se::Stream>> streams;
-
-  auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
-    for (std::unique_ptr<se::Stream>& stream : streams) {
-      backend->ReleaseStream(std::move(stream));
-    }
-  });
+  std::vector<Pool<se::Stream>::SmartPtr> streams;
 
   for (se::StreamExecutor* executor : executors) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
-                        backend->AcquireStream(executor));
-    // Push back after so that the releaser only sees real streams.
+    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
   }
 
   // Set up run options.
-  std::vector<ExecutableRunOptions> run_options;
-  for (const std::unique_ptr<se::Stream>& stream : streams) {
-    run_options.emplace_back();
-    auto& options = run_options.back();
+  std::vector<ServiceExecutableRunOptions> run_options;
+  for (const Pool<se::Stream>::SmartPtr& stream : streams) {
+    ExecutableRunOptions options;
     options.set_stream(stream.get());
     options.set_allocator(backend->memory_allocator());
     options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
     options.set_intra_op_thread_pool(
         backend->eigen_intra_op_thread_pool_device());
+    run_options.emplace_back(options, backend->StreamBorrower());
   }
 
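`BorrowStream` together with `Pool<se::Stream>::SmartPtr` replaces the earlier Acquire/Release pair plus `MakeCleanup`: the borrowed pointer's deleter returns the stream to the pool when it goes out of scope. A self-contained sketch of that borrow/return pattern (a generic pool, not XLA's actual `Pool`; note the pool must outlive every borrowed pointer, since the deleter captures it):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Minimal object pool whose Borrow() hands out a smart pointer that returns
// the object to the pool on destruction, instead of destroying it.
template <typename T>
class Pool {
 public:
  using SmartPtr = std::unique_ptr<T, std::function<void(T*)>>;

  SmartPtr Borrow() {
    T* object;
    if (free_list_.empty()) {
      object = new T();  // pool empty: allocate a fresh object
    } else {
      object = free_list_.back().release();
      free_list_.pop_back();
    }
    // The deleter puts the object back rather than deleting it.
    return SmartPtr(object,
                    [this](T* released) { free_list_.emplace_back(released); });
  }

  size_t available() const { return free_list_.size(); }

 private:
  std::vector<std::unique_ptr<T>> free_list_;
};

int main() {
  Pool<int> pool;
  {
    Pool<int>::SmartPtr borrowed = pool.Borrow();
    *borrowed = 42;  // use the pooled object
  }                  // scope exit returns it to the pool automatically
  std::cout << "available after return: " << pool.available() << "\n";  // 1
  return 0;
}
```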
   // Asynchronously launch all executables.
   std::vector<GlobalDataHandle> result_handles;
-  for (int64 i = 0; i < executables.size(); i++) {
+  for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
+       i < executables.size(); i++) {
     TF_ASSIGN_OR_RETURN(
         perftools::gputools::DeviceMemoryBase result,
         executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
@@ -555,52 +553,39 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   TF_RET_CHECK(!backend->Replicas().empty());
 
   // Set up streams.
-  std::vector<std::unique_ptr<se::Stream>> streams;
-
-  auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
-    for (std::unique_ptr<se::Stream>& stream : streams) {
-      backend->ReleaseStream(std::move(stream));
-    }
-  });
+  std::vector<Pool<se::Stream>::SmartPtr> streams;
 
   for (se::StreamExecutor* executor : backend->Replicas()) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
-                        backend->AcquireStream(executor));
-    // Push back after so that the releaser only sees real streams.
+    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+                        backend->BorrowStream(executor));
     streams.push_back(std::move(stream));
   }
 
   // Set up run options.
-  std::vector<ExecutableRunOptions> run_options;
-  for (const std::unique_ptr<se::Stream>& stream : streams) {
-    run_options.emplace_back();
-    auto& options = run_options.back();
+  std::vector<ServiceExecutableRunOptions> run_options;
+  for (const Pool<se::Stream>::SmartPtr& stream : streams) {
+    ExecutableRunOptions options;
     options.set_stream(stream.get());
     options.set_allocator(backend->memory_allocator());
     options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
     options.set_intra_op_thread_pool(
         backend->eigen_intra_op_thread_pool_device());
+    run_options.emplace_back(options, backend->StreamBorrower(),
+                             backend->inter_op_thread_pool());
   }
 
   perftools::gputools::DeviceMemoryBase result;
   if (backend->Replicas().size() == 1) {
     TF_ASSIGN_OR_RETURN(
-        result, ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
-                    executable, &run_options[0], profile,
-                    [&arguments](Executable* executable,
-                                 const ExecutableRunOptions* run_options,
-                                 HloExecutionProfile* hlo_execution_profile) {
-                      return executable->ExecuteOnStream(run_options, arguments,
-                                                         hlo_execution_profile);
-                    }));
+        result,
+        executable->ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
+            &run_options[0], profile, arguments));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
         repeated_arguments(backend->Replicas().size(), arguments);
 
-    TF_ASSIGN_OR_RETURN(
-        auto results,
-        executable->ExecuteOnStreams(run_options, repeated_arguments));
+    TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
+                                          run_options, repeated_arguments));
     TF_RET_CHECK(!results.empty());
     result = results[0];
   }
@@ -668,6 +653,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
         ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
                                     executor->device_ordinal()));
     std::vector<perftools::gputools::DeviceMemoryBase> arguments;
+    arguments.reserve(arg_allocations.size());
     for (const Allocation* allocation : arg_allocations) {
       arguments.push_back(allocation->device_memory());
     }
 
@@ -676,7 +662,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     // the program and the argument allocations.
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                         CreateModuleConfig(*program_shape, arg_allocations,
-                                           request.execution_options()));
+                                           request.execution_options(),
+                                           execute_backend_.get()));
 
     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
             << module_config->entry_computation_layout().ToString();
 
@@ -695,6 +682,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
       BuildExecutables(versioned_handles, std::move(module_configs),
                        execute_backend_.get(), executors));
   std::vector<Executable*> executable_ptrs;
+  executable_ptrs.reserve(executables.size());
   for (const auto& executable : executables) {
     executable_ptrs.push_back(executable.get());
   }
@@ -761,14 +749,16 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "Execute created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
 
   std::vector<perftools::gputools::DeviceMemoryBase> arguments;
+  arguments.reserve(arg_allocations.size());
   for (const Allocation* allocation : arg_allocations) {
     arguments.push_back(allocation->device_memory());
   }
@@ -828,14 +818,16 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
 
   std::vector<perftools::gputools::DeviceMemoryBase> arguments;
+  arguments.reserve(arg_allocations.size());
   for (const Allocation* allocation : arg_allocations) {
     arguments.push_back(allocation->device_memory());
   }
@@ -851,23 +843,16 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
   TF_RET_CHECK(!execute_backend_->Replicas().empty());
 
   // Set up streams.
-  std::vector<std::unique_ptr<se::Stream>> streams;
-
-  auto stream_releaser = ::tensorflow::gtl::MakeCleanup([this, &streams]() {
-    for (std::unique_ptr<se::Stream>& stream : streams) {
-      execute_backend_->ReleaseStream(std::move(stream));
-    }
-  });
+  std::vector<Pool<se::Stream>::SmartPtr> streams;
 
   for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
-                        execute_backend_->AcquireStream(executor));
-    // Push back after so that the releaser only sees real streams.
+    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+                        execute_backend_->BorrowStream(executor));
     streams.push_back(std::move(stream));
   }
 
   perftools::gputools::DeviceMemoryBase result_data;
-  for (const std::unique_ptr<se::Stream>& stream : streams) {
+  for (const Pool<se::Stream>::SmartPtr& stream : streams) {
     ExecutableRunOptions options;
     options.set_stream(stream.get());
     options.set_allocator(execute_backend_->memory_allocator());
@@ -875,8 +860,12 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
     options.set_intra_op_thread_pool(
         execute_backend_->eigen_intra_op_thread_pool_device());
 
-    TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase this_result_data,
-                        executable->ExecuteAsyncOnStream(&options, arguments));
+    ServiceExecutableRunOptions service_options(
+        options, execute_backend_->StreamBorrower());
+
+    TF_ASSIGN_OR_RETURN(
+        perftools::gputools::DeviceMemoryBase this_result_data,
+        executable->ExecuteAsyncOnStream(&service_options, arguments));
 
     // Take the first result.
     if (result_data == nullptr) {
@@ -927,13 +916,15 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
     literal_shape = &allocation->shape();
   }
 
-  return LiteralFromAllocation(allocation, *literal_shape,
-                               result->mutable_literal());
+  Literal literal;
+  auto status = LiteralFromAllocation(allocation, *literal_shape, &literal);
+  *result->mutable_literal() = literal.ToProto();
+  return status;
 }
 
 tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
                                              TransferToServerResponse* result) {
-  const Literal& literal = arg->literal();
+  Literal literal = Literal(arg->literal());
   const Shape& shape = literal.shape();
 
   if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
@@ -945,9 +936,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
 
   se::StreamExecutor* stream_executor;
   if (arg->has_device_handle()) {
-    TF_ASSIGN_OR_RETURN(
-        stream_executor,
-        execute_backend_->stream_executor(arg->device_handle().handle()));
+    TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
+                                             arg->device_handle().handle()));
   } else {
     stream_executor = execute_backend_->default_stream_executor();
   }
@@ -964,12 +954,10 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
 
   *result->mutable_data() = allocation_tracker_.Register(
       execute_backend_.get(), stream_executor->device_ordinal(), allocation,
-      shape, tensorflow::strings::StrCat("TransferToServer literal of size ",
-                                         allocation_size));
+      shape, StrCat("TransferToServer literal of size ", allocation_size));
 
-  TF_ASSIGN_OR_RETURN(
-      auto replicas,
-      execute_backend_->Replicas(stream_executor->device_ordinal()));
+  TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+                                         stream_executor->device_ordinal()));
   for (se::StreamExecutor* executor : replicas) {
     TF_RETURN_IF_ERROR(
         execute_backend_->transfer_manager()->TransferLiteralToDevice(
@@ -984,25 +972,51 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
     return FailedPrecondition(
         "%s",
-        tensorflow::strings::StrCat(
-            "The replica_id=", arg->replica_id(),
-            " on TransferToInfeedRequest not in range [0, replica_count=",
-            replica_count, ").")
+        StrCat("The replica_id=", arg->replica_id(),
+               " on TransferToInfeedRequest not in range [0, replica_count=",
+               replica_count, ").")
             .c_str());
   }
 
   se::StreamExecutor* executor;
   if (arg->has_device_handle()) {
-    TF_ASSIGN_OR_RETURN(
-        auto replicas,
-        execute_backend_->Replicas(arg->device_handle().handle()));
+    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+                                           arg->device_handle().handle()));
     executor = replicas[arg->replica_id()];
   } else {
     executor = execute_backend_->Replicas()[arg->replica_id()];
   }
 
   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
-      executor, arg->literal());
+      executor, Literal(arg->literal()));
+}
+
+tensorflow::Status Service::TransferFromOutfeed(
+    const TransferFromOutfeedRequest* arg,
+    TransferFromOutfeedResponse* result) {
+  const int64 replica_count = execute_backend_->Replicas().size();
+  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
+    return FailedPrecondition(
+        "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
+        "%lld)",
+        arg->replica_id(), replica_count);
+  }
+
+  se::StreamExecutor* executor;
+  if (arg->has_device_handle()) {
+    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+                                           arg->device_handle().handle()));
+    executor = replicas[arg->replica_id()];
+  } else {
+    executor = execute_backend_->Replicas()[arg->replica_id()];
+  }
+
+  Literal literal;
+  TF_RETURN_IF_ERROR(
+      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
+          executor, arg->shape_with_layout(), &literal));
+  *result->mutable_literal() = literal.ToProto();
+  return tensorflow::Status::OK();
 }
 
 tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
@@ -1010,71 +1024,6 @@ tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
   return execute_backend_->ResetDevices();
 }
 
-tensorflow::Status Service::TransferToClientInProcess(
-    const TransferToClientInProcessRequest* arg,
-    TransferToClientInProcessResponse* result) {
-  TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToClientInProcess"));
-
-  TF_ASSIGN_OR_RETURN(const Allocation* allocation,
-                      allocation_tracker_.Resolve(arg->data()));
-
-  void* buffer = reinterpret_cast<void*>(arg->buffer());
-  int64 size = ShapeUtil::ByteSizeOf(allocation->shape());
-  TF_ASSIGN_OR_RETURN(
-      se::StreamExecutor * executor,
-      allocation->backend()->stream_executor(allocation->device_ordinal()));
-
-  return allocation->backend()->transfer_manager()->TransferBufferFromDevice(
-      executor, allocation->device_memory(), size, buffer);
-}
-
-tensorflow::Status Service::TransferToServerInProcess(
-    const TransferToServerInProcessRequest* arg,
-    TransferToServerInProcessResponse* result) {
-  TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToServerInProcess"));
-
-  const Shape& shape = arg->shape();
-
-  if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
-    // TODO(b/32990684): Tuple transfers to host end up allocating further
-    // buffers - implement that correctly.
-    return Unimplemented(
-        "Tuple transfers to the device not supported with replication.");
-  }
-
-  if (!LayoutUtil::HasLayout(shape)) {
-    return InvalidArgument("shape must have layout");
-  }
-
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
-
-  const void* buffer = reinterpret_cast<const void*>(arg->buffer());
-
-  // Allocate memory on the device, using the stream executor. The size of the
-  // allocation is obtained by examining the shape of the literal passed from
-  // the client. An allocation handle is returned in the response.
-  int64 allocation_size =
-      execute_backend_->transfer_manager()->GetByteSizeRequirement(shape);
-
-  se::StreamExecutor* stream_executor =
-      execute_backend_->default_stream_executor();
-
-  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
-                      execute_backend_->memory_allocator()->Allocate(
-                          stream_executor->device_ordinal(), allocation_size));
-
-  *result->mutable_data() = allocation_tracker_.Register(
-      execute_backend_.get(), stream_executor->device_ordinal(), allocation,
-      shape, tensorflow::strings::StrCat("TransferToServer literal of size ",
-                                         allocation_size));
-
-  for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
-    TF_RETURN_IF_ERROR(
-        execute_backend_->transfer_manager()->TransferBufferToDevice(
-            executor, allocation_size, buffer, &allocation));
-  }
-  return tensorflow::Status::OK();
-}
-
 tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
                                        IsConstantResponse* result) {
   TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
@@ -1123,7 +1072,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
 
   ExecutionOptions execution_options;
-  execution_options.set_disable_fast_math(true);
+  execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
   *execution_options.mutable_shape_with_output_layout() =
       program_shape.result();
@@ -1136,7 +1085,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   }
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(program_shape, {}, execution_options));
+                      CreateModuleConfig(program_shape, {}, execution_options,
+                                         compute_constant_backend_.get()));
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<Executable> executable,
@@ -1172,9 +1122,8 @@ tensorflow::Status Service::GetComputationShape(
   VersionedComputationHandle versioned_handle =
       computation->GetVersionedHandle();
 
-  TF_ASSIGN_OR_RETURN(
-      auto program_shape,
-      computation->ComputeProgramShape(versioned_handle.version));
+  TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape(
+                                              versioned_handle.version));
   *result->mutable_program_shape() = *program_shape;
   return tensorflow::Status::OK();
 }
@@ -1197,13 +1146,15 @@ tensorflow::Status Service::GetComputationStats(
   VersionedComputationHandle versioned_handle =
       user_computation->GetVersionedHandle();
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
-                      computation_tracker_.BuildHloModule(versioned_handle));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> module,
+      computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()));
 
   MakeHloDumper()(*module, "computation statistics subject");
 
   // Run HLO analysis to get the computation statistics.
-  HloCostAnalysis analysis;
+  HloCostAnalysis analysis(
+      execute_backend_->compiler()->ShapeSizeBytesFunction());
 
   TF_RETURN_IF_ERROR(
       module->entry_computation()->root_instruction()->Accept(&analysis));
 
@@ -1241,57 +1192,63 @@ tensorflow::Status Service::AddInstruction(
 
 tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
   TF_ASSIGN_OR_RETURN(UserComputation * computation,
                       computation_tracker_.Resolve(arg->computation()));
-  StatusOr<ComputationDataHandle> handle;
+  StatusOr<ComputationDataHandle> handle_status;
 
   switch (arg->op_case()) {
     case OpRequest::kBinaryOpRequest:
-      handle = computation->AddBinaryInstruction(arg->binary_op_request());
+      handle_status =
+          computation->AddBinaryInstruction(arg->binary_op_request());
       break;
     case OpRequest::kBroadcastRequest:
-      handle = computation->AddBroadcastInstruction(arg->broadcast_request());
+      handle_status =
+          computation->AddBroadcastInstruction(arg->broadcast_request());
      break;
    case OpRequest::kCallRequest: {
      TF_ASSIGN_OR_RETURN(
          UserComputation * to_apply,
          computation_tracker_.Resolve(arg->call_request().to_apply()));
-      handle = computation->AddCallInstruction(arg->call_request(), *to_apply);
+      handle_status =
+          computation->AddCallInstruction(arg->call_request(), *to_apply);
      break;
    }
    case OpRequest::kConcatenateRequest:
-      handle =
+      handle_status =
          computation->AddConcatenateInstruction(arg->concatenate_request());
      break;
    case OpRequest::kConstantRequest:
-      handle = computation->AddConstantInstruction(arg->constant_request());
+      handle_status =
+          computation->AddConstantInstruction(arg->constant_request());
      break;
    case OpRequest::kConvertRequest:
-      handle = computation->AddConvertInstruction(arg->convert_request());
+      handle_status =
+          computation->AddConvertInstruction(arg->convert_request());
      break;
    case OpRequest::kConvolveRequest:
-      handle = computation->AddConvolveInstruction(arg->convolve_request());
+      handle_status =
+          computation->AddConvolveInstruction(arg->convolve_request());
      break;
    case OpRequest::kCrossReplicaSumRequest:
-      handle = computation->AddCrossReplicaSumInstruction(
+      handle_status = computation->AddCrossReplicaSumInstruction(
          arg->cross_replica_sum_request());
      break;
    case OpRequest::kCustomCallRequest:
-      handle =
+      handle_status =
          computation->AddCustomCallInstruction(arg->custom_call_request());
      break;
    case OpRequest::kDynamicSliceRequest:
-      handle =
+      handle_status =
          computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
      break;
    case OpRequest::kDynamicUpdateSliceRequest:
-      handle = computation->AddDynamicUpdateSliceInstruction(
+      handle_status = computation->AddDynamicUpdateSliceInstruction(
          arg->dynamic_update_slice_request());
      break;
    case OpRequest::kGetTupleElementRequest:
-      handle = computation->AddGetTupleElementInstruction(
+      handle_status = computation->AddGetTupleElementInstruction(
          arg->get_tuple_element_request());
      break;
    case OpRequest::kInfeedRequest:
-      handle = computation->AddInfeedInstruction(arg->infeed_request());
+      handle_status = computation->AddInfeedInstruction(arg->infeed_request());
      break;
    case OpRequest::kOutfeedRequest:
      TF_RETURN_IF_ERROR(
@@ -1301,20 +1258,22 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
      TF_ASSIGN_OR_RETURN(
          UserComputation * to_apply,
          computation_tracker_.Resolve(arg->map_request().to_apply()));
-      handle = computation->AddMapInstruction(arg->map_request(), *to_apply);
+      handle_status =
+          computation->AddMapInstruction(arg->map_request(), *to_apply);
      break;
    }
    case OpRequest::kPadRequest:
-      handle = computation->AddPadInstruction(arg->pad_request());
+      handle_status = computation->AddPadInstruction(arg->pad_request());
       break;
     case OpRequest::kParameterRequest:
-      handle = computation->AddParameterInstruction(arg->parameter_request());
+      handle_status =
+          computation->AddParameterInstruction(arg->parameter_request());
       break;
     case OpRequest::kReduceRequest: {
       TF_ASSIGN_OR_RETURN(
           UserComputation * to_apply,
           computation_tracker_.Resolve(arg->reduce_request().to_apply()));
-      handle =
+      handle_status =
           computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
       break;
     }
@@ -1322,18 +1281,20 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
       TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
                           computation_tracker_.Resolve(
                               arg->reduce_window_request().to_apply()));
-      handle = computation->AddReduceWindowInstruction(
+      handle_status = computation->AddReduceWindowInstruction(
           arg->reduce_window_request(), *to_apply);
       break;
     }
     case OpRequest::kReshapeRequest:
-      handle = computation->AddReshapeInstruction(arg->reshape_request());
+      handle_status =
+          computation->AddReshapeInstruction(arg->reshape_request());
       break;
     case OpRequest::kReverseRequest:
-      handle = computation->AddReverseInstruction(arg->reverse_request());
+      handle_status =
+          computation->AddReverseInstruction(arg->reverse_request());
      break;
    case OpRequest::kRngRequest:
-      handle = computation->AddRngInstruction(arg->rng_request());
+      handle_status = computation->AddRngInstruction(arg->rng_request());
      break;
    case OpRequest::kSelectAndScatterRequest: {
      TF_ASSIGN_OR_RETURN(UserComputation * select,
@@ -1342,23 +1303,29 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
      TF_ASSIGN_OR_RETURN(UserComputation * scatter,
                          computation_tracker_.Resolve(
                              arg->select_and_scatter_request().scatter()));
-      handle = computation->AddSelectAndScatterInstruction(
+      handle_status = computation->AddSelectAndScatterInstruction(
          arg->select_and_scatter_request(), *select, *scatter);
      break;
    }
    case OpRequest::kSliceRequest:
-      handle = computation->AddSliceInstruction(arg->slice_request());
+      handle_status = computation->AddSliceInstruction(arg->slice_request());
      break;
    case OpRequest::kTernaryOpRequest:
-      handle = computation->AddTernaryInstruction(arg->ternary_op_request());
+      handle_status =
+          computation->AddTernaryInstruction(arg->ternary_op_request());
      break;
    case OpRequest::kTraceRequest:
      return computation->AddTraceInstruction(arg->trace_request());
+    case OpRequest::kTransposeRequest:
+      handle_status =
+          computation->AddTransposeInstruction(arg->transpose_request());
+      break;
    case OpRequest::kUnaryOpRequest:
-      handle = computation->AddUnaryInstruction(arg->unary_op_request());
+      handle_status = computation->AddUnaryInstruction(arg->unary_op_request());
      break;
    case OpRequest::kVariadicOpRequest:
-      handle = computation->AddVariadicInstruction(arg->variadic_op_request());
+      handle_status =
+          computation->AddVariadicInstruction(arg->variadic_op_request());
      break;
    case OpRequest::kWhileRequest: {
      TF_ASSIGN_OR_RETURN(
@@ -1367,8 +1334,8 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
      TF_ASSIGN_OR_RETURN(
          UserComputation * body,
          computation_tracker_.Resolve(arg->while_request().body()));
-      handle = computation->AddWhileInstruction(arg->while_request(),
-                                                *condition, *body);
+      handle_status = computation->AddWhileInstruction(arg->while_request(),
+                                                       *condition, *body);
      break;
    }
    case OpRequest::kSendRequest: {
@@ -1380,13 +1347,19 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
    case OpRequest::kRecvRequest: {
       TF_RETURN_IF_ERROR(
           channel_tracker_.RegisterRecv(arg->recv_request().channel_handle()));
-      handle = computation->AddRecvInstruction(arg->recv_request());
+      handle_status = computation->AddRecvInstruction(arg->recv_request());
       break;
     }
     default:
       return InvalidArgument("Unsupported operation");
   }
-  TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle);
+  TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status);
+
+  // We set the debug metadata here, because we slice off part of the OpRequest
+  // proto in the above switch statement.
+  TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
+  TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
+
   return tensorflow::Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 9c4b0f44c82..abd1281bdd0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -63,9 +63,14 @@ class ServiceOptions {
   ServiceOptions& set_number_of_replicas(int number_of_replicas);
   int number_of_replicas() const;
 
+  // Sets the thread pool size for parallel execution of an individual
+  // operator.
+  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
+  int intra_op_parallelism_threads() const;
+
  private:
   perftools::gputools::Platform* platform_ = nullptr;
   int number_of_replicas_ = -1;
+  int intra_op_parallelism_threads_ = -1;
 };
 
 // The XLA service object, which is the same across all
@@ -146,11 +151,6 @@ class Service : public ServiceInterface {
       const TransferToClientRequest* arg,
       TransferToClientResponse* result) override;
 
-  // Requests that global data be copied into a buffer supplied by the client.
-  tensorflow::Status TransferToClientInProcess(
-      const TransferToClientInProcessRequest* arg,
-      TransferToClientInProcessResponse* result) override;
-
   // Transfers data from a literal provided by the client, into device memory.
   tensorflow::Status TransferToServer(
       const TransferToServerRequest* arg,
@@ -162,6 +162,12 @@ class Service : public ServiceInterface {
       const TransferToInfeedRequest* arg,
       TransferToInfeedResponse* result) override;
 
+  // Transfers data from the Outfeed of the device to the literal provided by
+  // the client.
+  tensorflow::Status TransferFromOutfeed(
+      const TransferFromOutfeedRequest* arg,
+      TransferFromOutfeedResponse* result) override;
+
   // Resets devices, clearing all existing state on all the devices associated
   // with this service (including memory allocated on the devices).
   //
@@ -174,11 +180,6 @@ class Service : public ServiceInterface {
   tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                  ResetDeviceResponse* result) override;
 
-  // Transfers data from a buffer provided by the client, into device memory.
-  tensorflow::Status TransferToServerInProcess(
-      const TransferToServerInProcessRequest* arg,
-      TransferToServerInProcessResponse* result) override;
-
   // Tests if an expression is a compile-time constant.
   tensorflow::Status IsConstant(const IsConstantRequest* arg,
                                 IsConstantResponse* result) override;
@@ -243,6 +244,8 @@ class Service : public ServiceInterface {
   Backend* mutable_backend() { return execute_backend_.get(); }
 
  protected:
+  friend class LocalExecutable;
+
   // The constructor is private. Use the NewService factory to create new
   // service objects.
   Service(std::unique_ptr<Backend> backend,
@@ -257,11 +260,11 @@ class Service : public ServiceInterface {
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
       const Backend* backend, int device_ordinal);
 
-  // Create a Hlo module config foe the given program shape and arguments.
+  // Create a Hlo module config for the given program shape and arguments.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      const ExecutionOptions& execution_options, Backend* backend);
 
   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to
@@ -320,10 +323,6 @@ class Service : public ServiceInterface {
           executors,
       tensorflow::gtl::ArraySlice<string> result_tags);
 
-  // Dumps the executed HLO according to service-associated flags.
-  static void DumpExecutedHlo(const HloModule& module, const string& label,
-                              const HloExecutionProfile* profile);
-
   // Returns an HLO dumper for use in the compiler (it refers to flags
   // associated with the service).
   static Compiler::HloDumper MakeHloDumper();
@@ -347,21 +346,6 @@ class Service : public ServiceInterface {
   tensorflow::Status ValidateResultShapeWithLayout(
       const Shape& shape_with_layout, const Shape& result_shape) const;
 
-  // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a
-  // timer for the execution, sets up HLO profiling if enabled, and fills in the
-  // given ExecutionProfile if non-null. The given execute_func should be a
-  // function which calls the desired ExecuteOnStream overload with the supplied
-  // arguments. The ExecuteOnStream overloads return different types so this
-  // method is templated on return-type of the execute function.
-  template <typename ReturnT>
-  ReturnT ExecuteOnStreamWrapper(
-      Executable* executable, const ExecutableRunOptions* run_options,
-      ExecutionProfile* profile,
-      std::function<ReturnT(Executable*, const ExecutableRunOptions*,
-                            HloExecutionProfile*)>
-          execute_func);
-
   // Tracks computations built via the API.
   ComputationTracker computation_tracker_;
 
@@ -391,73 +375,6 @@ class Service : public ServiceInterface {
   TF_DISALLOW_COPY_AND_ASSIGN(Service);
 };
 
-template <typename ReturnT>
-ReturnT Service::ExecuteOnStreamWrapper(
-    Executable* executable, const ExecutableRunOptions* run_options,
-    ExecutionProfile* profile,
-    std::function<ReturnT(Executable*, const ExecutableRunOptions*,
-                          HloExecutionProfile*)>
-        execute_func) {
-  perftools::gputools::Stream* stream = run_options->stream();
-  std::unique_ptr<perftools::gputools::Timer> timer;
-  if (profile != nullptr) {
-    timer.reset(new perftools::gputools::Timer(stream->parent()));
-    stream->InitTimer(timer.get()).ThenStartTimer(timer.get());
-  }
-
-  VLOG(1) << "enqueueing executable on stream...";
-  // If the profiling flag isn't enabled, we pass nullptr as the profile to
-  // indicate profiling is not requested.
-  HloExecutionProfile hlo_execution_profile;
-  legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
-  HloExecutionProfile* profile_ptr =
-      flags->xla_hlo_profile && executable->hlo_profiling_enabled()
-          ? &hlo_execution_profile
-          : nullptr;
-
-  auto return_value = execute_func(executable, run_options, profile_ptr);
-
-  if (profile != nullptr) {
-    VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
-    stream->ThenStopTimer(timer.get()).BlockHostUntilDone();
-    VLOG(1) << "done with block-host-until-done";
-
-    // Merge in run time profile information from the executable.
-    profile->MergeFrom(executable->execution_profile());
-
-    // Overall execution time (in nanoseconds) from the executor timer.
-    profile->set_compute_and_transfer_time_ns(timer->Nanoseconds());
-
-    // TODO(b/28123297): On GPU we end up including transfer time in
-    // the compute time this way. Instead, we should get the correct
-    // value by measuring it. Setting the field here at least lets
-    // benchmarks provide *some* value for GPU computations.
-    //
-    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
-    // the compute time without the transfer time, so this way we get the
-    // correct compute time. We should instead have the correct value for
-    // compute_and_transfer_time and set compute_time to the compute time.
-    if (profile->compute_time_ns() == 0) {
-      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
-    }
-  }
-
-  if (profile_ptr != nullptr) {
-    HloCostAnalysis analysis;
-    tensorflow::Status analysis_status =
-        executable->module().entry_computation()->root_instruction()->Accept(
-            &analysis);
-    if (analysis_status.ok()) {
-      XLA_LOG_LINES(tensorflow::INFO,
-                    profile_ptr->ToString(
-                        stream->parent()->GetDeviceDescription(), analysis));
-    }
-    DumpExecutedHlo(executable->module(), "Service::Execute", profile_ptr);
-  }
-
-  return return_value;
-}
 
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h
new file mode 100644
index 00000000000..017e5ef09ed
--- /dev/null
+++ b/tensorflow/compiler/xla/service/service_executable_run_options.h
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace xla {
+
+// Class containing options for running a LocalExecutable and other auxiliary
+// data, now only a stream cache for GPU backend.
+class ServiceExecutableRunOptions {
+ public:
+  using StreamBorrower =
+      std::function<StatusOr<Pool<perftools::gputools::Stream>::SmartPtr>(int)>;
+
+  explicit ServiceExecutableRunOptions(
+      ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr,
+      tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr)
+      : run_options_(std::move(run_options)),
+        borrow_stream_(std::move(borrow_stream)),
+        xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {}
+
+  // Returns reference or pointer to `ExecutableRunOptions` member.
+  const ExecutableRunOptions& run_options() const { return run_options_; }
+  ExecutableRunOptions* mutable_run_options() { return &run_options_; }
+
+  // Delegate to `ExecutableRunOptions` member.
+  perftools::gputools::Stream* stream() const { return run_options_.stream(); }
+  DeviceMemoryAllocator* allocator() const { return run_options_.allocator(); }
+  int device_ordinal() const { return run_options_.device_ordinal(); }
+
+  // Borrows a stream and returns a smart pointer which returns the stream on
+  // destruction.
+  StatusOr<Pool<perftools::gputools::Stream>::SmartPtr> BorrowStream(
+      int device_ordinal) const {
+    return borrow_stream_
+               ? borrow_stream_(device_ordinal)
+               : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache");
+  }
+
+  // Returns reference to thread pool for execution of XLA ops on CPU backend.
+  tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const {
+    return xla_intra_op_thread_pool_;
+  }
+
+ private:
+  ExecutableRunOptions run_options_;
+  StreamBorrower borrow_stream_;
+  tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_
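The `BorrowStream` accessor above illustrates the pattern this header is built around: an optional capability is carried as a possibly-empty `std::function`, and the accessor degrades to an error rather than crashing when the capability was not supplied. A standalone sketch of the same idea, with a simplified result struct standing in for XLA's `StatusOr` and `Pool` types:

```cpp
#include <functional>
#include <iostream>
#include <string>

// Tiny stand-in for StatusOr<T>: either a value or an error message.
struct BorrowResult {
  int stream_id = -1;
  std::string error;
  bool ok() const { return error.empty(); }
};

// Mirrors the ServiceExecutableRunOptions::BorrowStream pattern: the
// capability may be absent, and the accessor returns an error in that case.
class RunOptions {
 public:
  using StreamBorrower = std::function<BorrowResult(int device_ordinal)>;

  explicit RunOptions(StreamBorrower borrow_stream = nullptr)
      : borrow_stream_(std::move(borrow_stream)) {}

  BorrowResult BorrowStream(int device_ordinal) const {
    return borrow_stream_ ? borrow_stream_(device_ordinal)
                          : BorrowResult{-1, "No stream cache"};
  }

 private:
  StreamBorrower borrow_stream_;
};

int main() {
  RunOptions without_cache;
  std::cout << without_cache.BorrowStream(0).error << "\n";  // No stream cache

  RunOptions with_cache([](int ordinal) { return BorrowResult{ordinal}; });
  std::cout << with_cache.BorrowStream(3).stream_id << "\n";  // 3
  return 0;
}
```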
diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto
index fa4aa7b0a5f..bb8d1cd2a10 100644
--- a/tensorflow/compiler/xla/service/session.proto
+++ b/tensorflow/compiler/xla/service/session.proto
@@ -57,15 +57,6 @@ message SessionComputation {
   // Map from ComputationDataHandle value to operation request. The highest
   // ComputationDataHandle value corresponds to the root of the computation.
   map<int64, OperationRequest> requests = 3;
-
-  // The list of Trace requests in this SessionComputation.
-  repeated TraceRequest trace_requests = 4;
-
-  // The list of Send requests in this SessionComputation.
-  repeated SendRequest send_requests = 5;
-
-  // The list of Outfeed requests in this SessionComputation.
-  repeated OutfeedRequest outfeed_requests = 6;
 }
 
 // Describes a group of SessionComputations with an "entry point" computation
@@ -84,10 +75,10 @@ message SessionModule {
   repeated SessionComputation embedded_computations = 2;
 
   // The arguments passed to the computation.
-  repeated Literal arguments = 3;
+  repeated LiteralProto arguments = 3;
 
   // The result of the computation.
-  Literal result = 4;
+  LiteralProto result = 4;
 
   // The name of the platform used to run the computation.
   string execution_platform = 5;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index fbab2dfd4af..d6436cf988d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -208,6 +208,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
             PrimitiveType_Name(arg.element_type()).c_str());
       }
       return arg;
+
+    case UNOP_IS_FINITE:
+      if (!ShapeUtil::ElementIsFloating(arg)) {
+        return InvalidArgument(
+            "expected element type in shape to be floating point for IsFinite "
+            "operation; got %s",
+            PrimitiveType_Name(arg.element_type()).c_str());
+      }
+      return ShapeUtil::ChangeElementType(arg, PRED);
+
     default:
       return InvalidArgument("unknown operation %s",
                              UnaryOperation_Name(operation).c_str());
@@ -217,7 +227,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
     tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
     const int64 dimension) {
-  if (arg_shapes.size() == 0) {
+  if (arg_shapes.empty()) {
     return InvalidArgument("Concatenate expects at least one argument");
   }
   if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
@@ -234,8 +244,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
     }
     if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
       return InvalidArgument(
-          "cannot concatenate arrays with different ranks: %lld vs %lld",
-          ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
+          "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
+          "(%s)",
+          ShapeUtil::Rank(*arg_shape),
+          ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
+          ShapeUtil::HumanString(*shape).c_str());
     }
     if (arg_shape->element_type() != shape->element_type()) {
       return InvalidArgument(
@@ -299,6 +312,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
     return InvalidArgument(
         "the rank of the operand and the padding configuration do not match.");
   }
+  if (operand_shape.element_type() != padding_value_shape.element_type()) {
+    return InvalidArgument(
+        "the element types of the operands to pad do not match");
+  }
   std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
     dimensions[i] = operand_shape.dimensions(i) +
@@ -328,7 +345,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 
   // Check if both element types are the same.
   if (lhs.element_type() != rhs.element_type()) {
-    return fail("element types mismatch");
+    return fail("element types do not match");
   }
 
   if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
@@ -530,7 +547,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
   } else {
     // Ranks do not match, so perform InDim broadcasting using
-    // broadcast_dimensions. Scalar broadcasting is a special case of this).
+    // broadcast_dimensions. Scalar broadcasting is a special case of this.
     const Shape& larger_shape =
         ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
     const Shape& smaller_shape =
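The InDim broadcast mentioned in the comment above maps each dimension `i` of the lower-rank operand to dimension `broadcast_dimensions[i]` of the higher-rank operand, and the mapped extents have to agree. A rough standalone sketch of that compatibility rule (illustrative only, not XLA's actual inference code, which splits the degenerate size-1 case into a separate path):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Checks that the lower-rank shape `smaller` can be InDim-broadcast into
// `larger`: dimension i of `smaller` maps to dimension
// broadcast_dimensions[i] of `larger`, and the extents must match.
bool InDimBroadcastCompatible(
    const std::vector<int64_t>& smaller, const std::vector<int64_t>& larger,
    const std::vector<int64_t>& broadcast_dimensions) {
  if (broadcast_dimensions.size() != smaller.size()) return false;
  for (size_t i = 0; i < smaller.size(); ++i) {
    int64_t dim = broadcast_dimensions[i];
    if (dim < 0 || dim >= static_cast<int64_t>(larger.size())) return false;
    if (smaller[i] != larger[dim]) return false;
  }
  return true;
}

int main() {
  // [3] maps into dimension 1 of [2,3,4].
  std::cout << InDimBroadcastCompatible({3}, {2, 3, 4}, {1}) << "\n";  // 1
  // Scalar broadcasting: an empty dimension list is trivially compatible.
  std::cout << InDimBroadcastCompatible({}, {2, 3, 4}, {}) << "\n";    // 1
  // Mismatched extent.
  std::cout << InDimBroadcastCompatible({5}, {2, 3, 4}, {1}) << "\n";  // 0
  return 0;
}
```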
lhs : rhs; const Shape& smaller_shape = @@ -623,26 +640,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs)); switch (operation) { case TRIOP_CLAMP: - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); - if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) && - (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) { - return rhs; - } - if (ShapeUtil::Rank(rhs) == 0) { - if (ShapeUtil::Compatible(lhs, ehs)) { - return lhs; - } - return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs; - } - return Unimplemented("not yet implemented: %s, %s %s", - lhs.ShortDebugString().c_str(), - ehs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + return InferClampShape(lhs, rhs, ehs); case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); case TRIOP_UPDATE: @@ -681,7 +679,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferMapShape( tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply) { - if (arg_shapes.size() == 0) { + if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument"); } @@ -1007,7 +1005,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits) { + tensorflow::gtl::ArraySlice limits, + tensorflow::gtl::ArraySlice strides) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", @@ -1030,13 +1029,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension = 0; dimension < starts.size(); ++dimension) { int64 start_index = starts[dimension]; int64 limit_index = limits[dimension]; + int64 stride = strides[dimension]; if (start_index < 0) { return InvalidArgument("negative start index to slice: %lld", start_index); } - if (limit_index < 0) { - return InvalidArgument("negative limit index to slice: %lld", - limit_index); + if (stride == 0) { + return InvalidArgument("Zero stride"); } if (limit_index > arg.dimensions(dimension)) { return InvalidArgument( @@ -1044,18 +1043,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "size (%lld)", limit_index, arg.dimensions(dimension)); } - if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice", - limit_index, start_index); - } VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); - - sizes.push_back(limits[dimension] - starts[dimension]); + if (stride > 0) { + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index); + } + sizes.push_back((limit_index - start_index + stride - 1) / stride); + } else { + return InvalidArgument("Negative strides not supported"); + } } return ShapeUtil::MakeShape(arg.element_type(), sizes); @@ -1089,9 +1091,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 start_num_dims = start_indices_shape.dimensions(0); if 
(ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld must match rank %lld of " - "slice input", - start_num_dims, ShapeUtil::Rank(operand_shape)); + "dynamic slice start number of dimensions %lld (%s) must match rank " + "%lld of slice input (%s)", + start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), + ShapeUtil::Rank(operand_shape), + ShapeUtil::HumanString(operand_shape).c_str()); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { @@ -1103,7 +1107,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; - if (slice_dim_size <= 0) { + if (slice_dim_size < 0) { return InvalidArgument("negative size index to dynamic slice: %lld", slice_dim_size); } @@ -1150,9 +1154,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic update slice start number of dimensions %lld must match " - "rank %lld of slice input", - start_num_dims, ShapeUtil::Rank(operand_shape)); + "dynamic slice start number of dimensions %lld (%s) must match rank " + "%lld of slice input (%s)", + start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), + ShapeUtil::Rank(operand_shape), + ShapeUtil::HumanString(operand_shape).c_str()); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { @@ -1173,9 +1179,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 update_dim_size = update_shape.dimensions(dim); - if (update_dim_size <= 0) { + if (update_dim_size < 0) { return InvalidArgument( - "size index %lld to dynamic update slice must be > 0", + "size index %lld to dynamic update slice must be >= 0", update_dim_size); } if (update_dim_size > input_dim_size) { @@ -1322,6 +1328,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand); } +// TODO(b/36794510): Make broadcast semantics more consistent, by supporting +// "degenerate" cases, as with binary elementwise ops. 
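+//
+// Illustrative summary (ours, not part of the change itself) of the rules the
+// function below implements: Clamp keeps the operand's shape when every
+// non-scalar input is compatible with it, so Clamp(f32[], f32[64x48],
+// f32[64x48]) -> f32[64x48]; all three inputs must share an element type, so
+// Clamp(s32[], f32[64x48], f32[64x48]) is rejected.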
+/* static */ StatusOr ShapeInference::InferClampShape( + const Shape& min, const Shape& operand, const Shape& max) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + if (!ShapeUtil::SameElementType(min, operand) || + !ShapeUtil::SameElementType(max, operand)) { + return InvalidArgument("clamp op with different operand types: %s, %s, %s", + ShapeUtil::HumanString(min).c_str(), + ShapeUtil::HumanString(operand).c_str(), + ShapeUtil::HumanString(max).c_str()); + } + if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && + (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + return operand; + } + if (ShapeUtil::IsScalar(operand)) { + if (ShapeUtil::Compatible(min, max)) { + return min; + } else if (ShapeUtil::IsScalar(min)) { + return max; + } else if (ShapeUtil::IsScalar(max)) { + return min; + } + } + return Unimplemented( + "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(), + max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); +} + +// TODO(b/36794510): Make broadcast semantics more consistent, by supporting +// "degenerate" cases, as with binary elementwise ops, as well as scalar +// broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { if (!ShapeUtil::Compatible(on_true, on_false)) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index ced2f4d0017..0d270f99794 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,7 +109,8 @@ class ShapeInference { // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] static StatusOr InferSliceShape( const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits); + tensorflow::gtl::ArraySlice limits, + tensorflow::gtl::ArraySlice strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. @@ -190,6 +191,10 @@ class ShapeInference { BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + // Helper for inferring the shape of Clamp ops. + static StatusOr InferClampShape(const Shape& min, const Shape& operand, + const Shape& max); + // Helper for inferring the shape of Select ops. static StatusOr InferSelectShape(const Shape& pred, const Shape& on_true, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 5a1ae6b0024..8c731ae2976 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -20,12 +20,16 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace { +using ::testing::ContainsRegex; +using ::testing::HasSubstr; + class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. 
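The strided-slice change above computes each output dimension with a ceiling division. As a quick sanity check of that arithmetic, here is a minimal, standalone sketch; `StridedSliceDim` is our illustrative helper, not an XLA function, and it returns -1 where `InferSliceShape` would return `InvalidArgument`:

```c++
#include <cstdint>
#include <iostream>

// Illustrative helper (not part of XLA): output extent of one sliced
// dimension, mirroring the check-and-divide logic in InferSliceShape.
int64_t StridedSliceDim(int64_t start, int64_t limit, int64_t stride,
                        int64_t dim_size) {
  if (start < 0 || limit > dim_size || stride <= 0 || start > limit) {
    return -1;  // InferSliceShape returns InvalidArgument in these cases.
  }
  // Ceiling division: the count of indices start, start+stride, ...
  // that fall in the half-open interval [start, limit).
  return (limit - start + stride - 1) / stride;
}

int main() {
  std::cout << StridedSliceDim(32, 64, 2, 128) << "\n";  // 16
  std::cout << StridedSliceDim(0, 64, 4, 64) << "\n";    // 16
  std::cout << StridedSliceDim(15, 20, 2, 128) << "\n";  // 3
}
```

The printed values match the expectations in the `InferSliceShapeRank2WithStrides` and `InferSliceShapeRank2WithStridesNotIntegral` tests below.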
@@ -128,23 +132,21 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("operands to select must be the same shape")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("pred operand must have PRED")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH( - inferred_status_error3.status().error_message(), - testing::ContainsRegex("with non-scalar predicate with dimensionality")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( @@ -152,9 +154,101 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("pred operand must have PRED element type")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("pred operand must have PRED element type")); +} + +TEST_F(ShapeInferenceTest, ClampAllMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, + matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampAllScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMinScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMaxScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampOperandScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMinMatrix) { + auto 
inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMaxMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampOperandMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampBadShapes) { + // Type mismatch + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) + .ok()); + // Dimension mismatch + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_64_, vector_32_, vector_32_) + .ok()); + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_32_, vector_64_, vector_32_) + .ok()); + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + .ok()); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { @@ -205,8 +299,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { operand_shape_, select_program_shape_, window_, source_shape_fail, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("source shape does not match")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -216,9 +310,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH( - inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must take 2 parameters")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -228,8 +321,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - 
ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must have rank-0 PRED")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -239,8 +332,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's first parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -250,8 +343,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's second parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -405,8 +498,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { auto inferred_status = ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("each dimension exactly once")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("each dimension exactly once")); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { @@ -443,43 +536,42 @@ TEST_F(ShapeInferenceTest, Map) { auto no_args_error = ShapeInference::InferMapShape( {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(no_args_error.ok()); - ASSERT_MATCH(no_args_error.status().error_message(), - testing::ContainsRegex("expects at least one argument")); + ASSERT_THAT(no_args_error.status().error_message(), + HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(args_diff_shapes_error.ok()); - ASSERT_MATCH( - args_diff_shapes_error.status().error_message(), - testing::ContainsRegex("requires all operands to have the same shape")); + ASSERT_THAT(args_diff_shapes_error.status().error_message(), + HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); ASSERT_FALSE(arity_error.ok()); - ASSERT_MATCH(arity_error.status().error_message(), - testing::ContainsRegex("function arity must match")); + ASSERT_THAT(arity_error.status().error_message(), + HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); ASSERT_FALSE(output_shape_error.ok()); - ASSERT_MATCH(output_shape_error.status().error_message(), - testing::ContainsRegex("result has to be a scalar")); + ASSERT_THAT(output_shape_error.status().error_message(), + HasSubstr("result has to be a scalar")); auto 
param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); ASSERT_FALSE(param_shape_error.ok()); - ASSERT_MATCH(param_shape_error.status().error_message(), - testing::ContainsRegex("parameter has to be a scalar")); + ASSERT_THAT(param_shape_error.status().error_message(), + HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); ASSERT_FALSE(param_element_type_error.ok()); - ASSERT_MATCH(param_element_type_error.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(param_element_type_error.status().error_message(), + HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); @@ -490,26 +582,26 @@ TEST_F(ShapeInferenceTest, Map) { auto inferred_status_error1 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match number of arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("parameter type has to match argument")); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -563,8 +655,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("out-of-bounds dimension")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { @@ -573,8 +665,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("take 2 parameters")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, 
ErrorElementTypeVsApplyType) { @@ -583,23 +675,50 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("first parameter shape differs")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("first parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = - ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}); + ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1}); ASSERT_IS_OK(inferred_status.status()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred)); } +TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred)); +} + +TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4}); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred)); +} + +TEST_F(ShapeInferenceTest, InferInvalidStride) { + Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); + auto inferred_status = + ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1}); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, + inferred_status.status().code()); +} + TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); auto inferred_status = - ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}); + ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1}); ASSERT_FALSE(inferred_status.ok()); ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, inferred_status.status().code()); @@ -608,7 +727,7 @@ TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { TEST_F(ShapeInferenceTest, InferSliceShapeRank1) { Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); auto inferred_status = - ShapeInference::InferSliceShape(vector_shape, {2}, {4}); + ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1}); ASSERT_TRUE(inferred_status.ok()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2}))); @@ -726,8 +845,8 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { auto inferred_status = ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("dot only supports rank")); } // 3D 2D: error @@ -735,8 +854,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { auto inferred_status = ShapeInference::InferBinaryOpShape( BINOP_DOT, ShapeUtil::MakeShape(F32, 
{32, 32, 32}), matrix_32_64_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("dot only supports rank")); } // vector vector -> scalar @@ -848,46 +967,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("automatic")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH( - inferred_status_error2.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("size of broadcast_dimensions has to match")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("size of broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH( - inferred_status_error5.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH(inferred_status_error6.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -895,14 +1011,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); - ASSERT_MATCH(inferred_status_error7.status().error_message(), - 
testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error7.status().error_message(), + HasSubstr("broadcast dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); - ASSERT_MATCH(inferred_status_error8.status().error_message(), - testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error8.status().error_message(), + HasSubstr("broadcast dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -927,30 +1043,30 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("condition must take 1 arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); auto inferred_status_error2 = ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("body must take 1 arguments")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("condition must return a boolean")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH(inferred_status_error4.status().error_message(), - testing::ContainsRegex("parameter of condition and body")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("parameter of condition and body")); } // Tests for the concatenate instruction with proper shapes. 
@@ -980,49 +1096,44 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("Concatenate expects at least one argument")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("Concatenate expects at least one argument")); auto inferred_status_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: -1")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("dimension to concatenate along out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: 1")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("dimension to concatenate along out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( {&vector_32_, &tuple}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( + ASSERT_THAT( inferred_status_error4.status().error_message(), - testing::ContainsRegex( - "Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation.")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex( - "cannot concatenate arrays with different element types")); + ASSERT_THAT( + inferred_status_error5.status().error_message(), + HasSubstr("cannot concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH( - inferred_status_error6.status().error_message(), - testing::ContainsRegex("cannot concatenate arrays that differ in " - "dimensions other than the one being " - "concatenated")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + HasSubstr("cannot concatenate arrays that differ in " + "dimensions other than the one being " + "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { @@ -1063,27 +1174,27 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { auto inferred_status_error0 = ShapeInference::InferReverseShape(input_shape, {0, 2}); ASSERT_FALSE(inferred_status_error0.ok()); - ASSERT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("out-of-bounds")); auto inferred_status_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + 
HasSubstr("out-of-bounds")); auto inferred_status_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("duplicated")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("duplicated")); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); auto inferred_status_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("Expected non-tuple argument")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("Expected non-tuple argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1103,20 +1214,20 @@ TEST_F(ShapeInferenceTest, Call) { auto inferred_status_error0 = ShapeInference::InferCallShape( {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("arity must match")); auto inferred_status_error1 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("arity must match")); auto inferred_status_error2 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("parameter must match argument")); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index cf49fd72b7d..865be1b84f2 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -73,16 +73,13 @@ ShapedBuffer::MakeUnnestedTupleShapedBuffer( } TF_ASSIGN_OR_RETURN(std::unique_ptr shaped_buffer, MakeShapedBuffer(shape, platform, device_ordinal)); - TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElement( - [](const ShapeIndex& index, bool is_leaf, - size_t* buffer_element) -> tensorflow::Status { - if (is_leaf) { - CHECK_EQ(index.size(), 1); - *buffer_element = index[0]; - } - return tensorflow::Status::OK(); - })); + shaped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement( + [&shaped_buffer](const ShapeIndex& index, size_t* buffer_element) { + if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { + CHECK_EQ(index.size(), 1); + *buffer_element = index[0]; + } + }); shaped_buffer->mutable_buffers()->reserve(buffers.size()); for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) { shaped_buffer->mutable_buffers()->push_back(memory_base); @@ -126,10 +123,12 @@ ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape, // Allocate an appropriate sized buffer for each array element in the shape. 
TF_RETURN_IF_ERROR( - shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement( - [&shaped_buffer](const ShapeIndex& index, bool is_leaf, - size_t* buffer_entry) -> tensorflow::Status { - if (is_leaf) { + shaped_buffer->shape_index_to_buffer_entry_ + .ForEachMutableElementWithStatus([&shaped_buffer]( + const ShapeIndex& index, + size_t* buffer_entry) + -> tensorflow::Status { + if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { TF_ASSIGN_OR_RETURN( perftools::gputools::DeviceMemoryBase memory_base, shaped_buffer->allocator_->Allocate( diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index c7f6a13023d..4da0a0d3684 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -72,7 +72,7 @@ TransferManager::GetPlatformTransferManagers() { it->second.manager = (*it->second.creation_function)(); } - return it->second.manager; + return it->second.manager.get(); } Status TransferManager::TransferBufferFromDevice( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 7ffce452139..15f6b7bfb4a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -64,6 +65,12 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; + // Transfers the given literal from the Outfeed interface of the device, + // using the given executor. + virtual Status TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) = 0; + // Resets the devices associated with this transfer manager. virtual Status ResetDevices( tensorflow::gtl::ArraySlice @@ -110,7 +117,7 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, int64 size, const void* source, perftools::gputools::DeviceMemoryBase* destination); - typedef TransferManager* (*TransferManagerCreationFunction)(); + typedef std::unique_ptr (*TransferManagerCreationFunction)(); ///// // The TransferManager class also serves as a point to register objects for @@ -140,7 +147,7 @@ class TransferManager { // set up creation_function, and then we use that to lazily create // "manager" the first time GetForPlatform is invoked for a particular id. 
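  // Illustrative note (not part of the original comment): with this change
  // creation_function returns std::unique_ptr<TransferManager>, so State owns
  // the lazily-created manager, and GetForPlatform hands out a non-owning
  // pointer via manager.get(), as in the transfer_manager.cc hunk above.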
struct State { - TransferManager* manager = nullptr; + std::unique_ptr manager; TransferManagerCreationFunction creation_function = nullptr; }; diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index 564111c4f2b..ca38601d919 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { const Shape shape = ShapeUtil::MakeShape(U8, {4}); TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s()); + CHECK_EQ("klmn", literal.u8s_string()); } TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 07e0ce89f6a..a0c88c6bbc2 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -30,43 +32,55 @@ namespace xla { namespace { -bool IsOperandFoldableToDot(const HloInstruction& hlo) { - return hlo.IsRank2Transpose() && - hlo.user_count() == 1; // The dot is its only user. -} - -bool CanFoldOperandsIntoDot( +TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, - const TransposeFolding::IsTransposableGemmFn& is_transposable_gemm) { + const TransposeFolding::TransposableGemmOperandsFn& + transposable_gemm_operands) { if (HloOpcode::kDot != dot.opcode()) { - return false; + return {}; } - if (!is_transposable_gemm(dot)) { - return false; + TransposeFolding::OperandIndices operand_set; + for (int64 i = 0; i < dot.operand_count(); ++i) { + auto& operand = *dot.operand(i); + if (operand.IsRank2Transpose() && operand.user_count() == 1) { + operand_set.push_back(i); + } } - const HloInstruction* lhs = dot.operand(0); - const HloInstruction* rhs = dot.operand(1); - bool lhs_foldable = IsOperandFoldableToDot(*lhs); - bool rhs_foldable = IsOperandFoldableToDot(*rhs); - if (!lhs_foldable && !rhs_foldable) { - return false; - } - return true; + return transposable_gemm_operands(dot, operand_set); } +TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( + const HloInstruction& convolution, + const TransposeFolding::TransposableConvOperandsFn& + transposable_conv_operands) { + if (HloOpcode::kConvolution != convolution.opcode()) { + return {}; + } + + // We only support folding the RHS. + const int64 kRhsOperandIndex = 1; + auto& operand = *convolution.operand(kRhsOperandIndex); + if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { + return transposable_conv_operands(convolution, {kRhsOperandIndex}); + } + + return {}; +} + +using InstructionOperandsPair = + std::pair; + // Folds the operands of `dot` that are foldable transposes. `computation` is -// the parent HLO computation of `dot`. `module` is the parent HloModule of -// `computation`. 
+// the parent HLO computation of `dot`. // // Returns whether the module is changed. -bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) { +bool FoldTransposeIntoDot(InstructionOperandsPair pair) { + auto* dot = pair.first; std::vector instructions_to_fuse(1, dot); - for (HloInstruction* operand : dot->operands()) { - if (IsOperandFoldableToDot(*operand)) { - instructions_to_fuse.push_back(operand); - } + for (const int64 operand_index : pair.second) { + instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); } // Early-exit if no operands are foldable. @@ -74,33 +88,100 @@ bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) { return false; } - computation->CreateFusionInstruction( + dot->parent()->CreateFusionInstruction( instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); return true; } +// Folds the operands of `convolution` that are foldable transposes. +// `computation` is the parent HLO computation of `convolution`. +// +// Returns whether the module is changed. +bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { + auto& convolution = *pair.first; + + // We only support fusing the RHS transpose into convolution. + // + // ConvolutionDimensionNumbers doesn't make enough of a distinction between + // the output and the activations. + // + // TODO(b/37125184): Support transposing the LHS too. + if (pair.second.size() != 1 || pair.second.front() != 1) { + return false; + } + + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + HloInstruction& transpose = *convolution.mutable_operand(1); + CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the kernel dimension numbers. We + // need to apply the transpose permutation to the original shape to figure out + // what the new logical dimensions are. + ConvolutionDimensionNumbers new_dnums = dnums; + new_dnums.set_kernel_input_feature_dimension( + transpose_dimensions[dnums.kernel_input_feature_dimension()]); + new_dnums.set_kernel_output_feature_dimension( + transpose_dimensions[dnums.kernel_output_feature_dimension()]); + for (auto& kernel_spatial_dimension : + *new_dnums.mutable_kernel_spatial_dimensions()) { + kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + } + + auto new_conv = HloInstruction::CreateConvolve( + convolution.shape(), convolution.mutable_operand(0), &transpose_operand, + convolution.window(), new_dnums); + TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( + &convolution, std::move(new_conv))); + + return true; +} + } // namespace -TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm) - : is_transposable_gemm_(std::move(is_transposable_gemm)) {} +TransposeFolding::TransposeFolding( + TransposableGemmOperandsFn transposable_gemm_operands, + TransposableConvOperandsFn transposable_conv_operands) + : transposable_gemm_operands_(std::move(transposable_gemm_operands)), + transposable_conv_operands_(std::move(transposable_conv_operands)) {} StatusOr TransposeFolding::Run(HloModule* module) { // Modifying the graph while traversing is dangerous, so we find all folding // opportunities before actually folding them. 
- HloComputation* entry_computation = module->entry_computation(); - - std::vector foldable_dots; - auto visit_fn = [this, &foldable_dots](HloInstruction* instruction) { - if (CanFoldOperandsIntoDot(*instruction, is_transposable_gemm_)) { - foldable_dots.emplace_back(instruction); + std::vector> foldable_dots; + std::vector> foldable_convolutions; + auto visit_fn = [this, &foldable_dots, + &foldable_convolutions](HloInstruction* instruction) { + { + OperandIndices operand_indices = + CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_); + if (!operand_indices.empty()) { + foldable_dots.emplace_back(instruction, operand_indices); + } + } + { + OperandIndices operand_indices = CanFoldOperandsIntoConvolution( + *instruction, transposable_conv_operands_); + if (!operand_indices.empty()) { + foldable_convolutions.emplace_back( + std::make_pair(instruction, operand_indices)); + } } return tensorflow::Status::OK(); }; - TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + for (auto& comp : module->computations()) { + TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); + } bool changed = false; - for (HloInstruction* dot : foldable_dots) { - changed |= FoldTransposeIntoDot(dot, entry_computation); + for (InstructionOperandsPair& pair : foldable_dots) { + changed |= FoldTransposeIntoDot(pair); + } + for (InstructionOperandsPair& pair : foldable_convolutions) { + changed |= FoldTransposeIntoConvolution(pair); } return changed; } diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index d857c04ed8d..71e8446452f 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -25,16 +25,37 @@ namespace xla { // operator is implemented by a GEMM kernel that can transpose its inputs. class TransposeFolding : public HloPassInterface { public: - // IsTransposableGemmFn should return true iff the instruction argument is - // implemented as a GEMM kernel that supports transposing its arguments. - typedef std::function IsTransposableGemmFn; - explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm); + using OperandIndices = std::vector; + + // Returns the set of foldable operands for a given HLO and some candidate + // operands. + using FoldableOperands = std::function; + using TransposableGemmOperandsFn = FoldableOperands; + using TransposableConvOperandsFn = FoldableOperands; + + // Helper function to explicitly not fold transposes. + static OperandIndices NeverFoldTranspose(const HloInstruction&, + const OperandIndices&) { + return {}; + } + // transposable_gemm_operands returns the set of operands it wants to fold if + // the instruction argument is implemented as a GEMM kernel that supports + // transposing its arguments. + // + // transposable_conv_operands returns the set of operands it wants to fold if + // the instruction argument is implemented as a convolution that supports + // transposing its arguments. 
+ explicit TransposeFolding( + TransposableGemmOperandsFn transposable_gemm_operands, + TransposableConvOperandsFn transposable_conv_operands); tensorflow::StringPiece name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; private: - IsTransposableGemmFn is_transposable_gemm_; + TransposableGemmOperandsFn transposable_gemm_operands_; + TransposableConvOperandsFn transposable_conv_operands_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 09f932e29e6..c72d127ea86 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -16,16 +16,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include -#include +#include #include +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -35,12 +38,20 @@ namespace xla { class TransposeFoldingTest : public ::testing::Test { protected: void FoldTranspose(HloModule* module) { - TransposeFolding transpose_folding(gpu::ImplementedAsGemm); + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); EXPECT_IS_OK(transpose_folding.Run(module).status()); } }; -TEST_F(TransposeFoldingTest, FoldTranspose) { +TEST_F(TransposeFoldingTest, FoldDotTranspose) { auto builder = HloComputation::Builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), @@ -61,7 +72,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. 
- std::set instruction_set; + std::unordered_set instruction_set; for (auto& instruction : entry_computation->instructions()) { instruction_set.insert(instruction.get()); } @@ -77,7 +88,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) { EXPECT_EQ(4, fusion->fused_instructions().size()); } -TEST_F(TransposeFoldingTest, FoldTransposeConstant) { +TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { auto builder = HloComputation::Builder("entry_computation"); // 2x1 HloInstruction* const0 = builder.AddInstruction( @@ -115,7 +126,7 @@ TEST_F(TransposeFoldingTest, FoldTransposeConstant) { entry_computation->root_instruction()->fused_instructions().size()); } -TEST_F(TransposeFoldingTest, FuseWithConstantOperands) { +TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( @@ -139,11 +150,219 @@ TEST_F(TransposeFoldingTest, FuseWithConstantOperands) { EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. - EXPECT_MATCH(call->operands(), testing::UnorderedMatcher( - const1, const2, const3)); + EXPECT_THAT(call->operands(), + ::testing::UnorderedElementsAre(const1, const2, const3)); // The callee should contain 3 parameters and 3 binary operators. EXPECT_EQ(6, callee_computation->instructions().size()); } +TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/x, /*rhs=*/transpose_y)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + + HloInstruction* call = module.OutlineExpressionFromComputation( + {transpose_y, dot}, "outlined", entry_computation); + + FoldTranspose(&module); + + // Instructions after folding: x, y, and the fusion. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(call)) + << "call is not in entry_computation."; + CHECK(instruction_set.empty()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* fusion = + call->called_computations().front()->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + + // The fusion instruction should contain two parameters, one transpose and + // one dot. + EXPECT_EQ(4, fusion->fused_instructions().size()); +} + +// Test that a two dimension swap of the kernel gets folded into convolution. 
+TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size( + transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( + x->shape(), transpose_y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set<HloInstruction*> instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.kernel_input_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_output_feature_dimension()); + EXPECT_EQ(dnums.kernel_output_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_input_feature_dimension()); +} + +// Test that a complex transpose of the kernel gets folded into convolution.
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size( + transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( + x->shape(), transpose_y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set<HloInstruction*> instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.kernel_input_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_output_feature_dimension()); + EXPECT_EQ(dnums.kernel_spatial_dimensions(1), + new_conv->convolution_dimension_numbers() + .kernel_input_feature_dimension()); + EXPECT_EQ( + dnums.kernel_output_feature_dimension(), + new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0)); + EXPECT_EQ( + dnums.kernel_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); +} + +// Test that a transpose of the activations does not get folded into +// convolution.
+TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, transpose_x, and the convolution. + std::unordered_set<HloInstruction*> instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(transpose_x)) + << "transpose_x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(conv)) + << "conv is not in entry_computation."; + CHECK_EQ(0, instruction_set.size()) + << "entry_computation should contain exactly 4 instructions."; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0e0c0b02e3b..ad6f015c70e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -33,9 +33,9 @@ limitations under the License.
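The transpose-folding tests above all build the pass from a pair of callbacks, one for dots and one for convolutions. For readers following the new API, here is a minimal usage sketch inferred from that test fixture; it assumes, as the fixture does, that `TransposeFolding::OperandIndices` is both the parameter and return type of each callback, so returning the candidates unchanged folds every transposed operand while returning an empty set folds none.

```c++
#include "tensorflow/compiler/xla/service/transpose_folding.h"

namespace xla {

// Sketch: fold transposes feeding dots unconditionally, never fold them into
// convolutions. Mirrors the lambdas used by TransposeFoldingTest above.
StatusOr<bool> FoldDotTransposesOnly(HloModule* module) {
  TransposeFolding fold(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;  // Accept every candidate dot operand.
      },
      [](const HloInstruction& convolution,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return TransposeFolding::OperandIndices();  // Decline all candidates.
      });
  return fold.Run(module);
}

}  // namespace xla
```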
namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "] => ", buffer_->ToString(), ")"); + return tensorflow::strings::StrCat( + "BufferAlias(", instruction_->FullyQualifiedName(), "[", + tensorflow::str_util::Join(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -45,29 +45,27 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { bool PointsToSet::IsAmbiguous() const { bool ambiguous = false; - TF_CHECK_OK(ForEachElement( - [&ambiguous](const ShapeIndex& /*index*/, bool /*is_leaf*/, + ForEachElement( + [&ambiguous](const ShapeIndex& /*index*/, const std::vector<const LogicalBuffer*>& points_to) { ambiguous |= points_to.size() > 1; - return Status::OK(); - })); + }); return ambiguous; } bool PointsToSet::IsDistinct() const { bool distinct = true; std::set<const LogicalBuffer*> all_points_to; - TF_CHECK_OK(ForEachElement([&distinct, &all_points_to]( - const ShapeIndex& /*index*/, bool /*is_leaf*/, - const std::vector<const LogicalBuffer*>& points_to) { + ForEachElement([&distinct, &all_points_to]( + const ShapeIndex& /*index*/, + const std::vector<const LogicalBuffer*>& points_to) { for (auto& buffer : points_to) { if (all_points_to.count(buffer) != 0) { distinct = false; } all_points_to.insert(buffer); } - return Status::OK(); - })); + }); return distinct; } @@ -77,29 +75,27 @@ size_t PointsToSet::size() const { return CreateFlattenedSet().size(); } -std::set<const LogicalBuffer*> PointsToSet::CreateFlattenedSet() const { - std::set<const LogicalBuffer*> flat_set; - TF_CHECK_OK(ForEachElement( - [&flat_set](const ShapeIndex& /*index*/, bool /*is_leaf*/, - const std::vector<const LogicalBuffer*>& buffers) { - flat_set.insert(buffers.begin(), buffers.end()); - return Status::OK(); - })); +tensorflow::gtl::FlatSet<const LogicalBuffer*> PointsToSet::CreateFlattenedSet() + const { + tensorflow::gtl::FlatSet<const LogicalBuffer*> flat_set; + ForEachElement([&flat_set](const ShapeIndex& /*index*/, + const std::vector<const LogicalBuffer*>& buffers) { + flat_set.insert(buffers.begin(), buffers.end()); + }); return flat_set; } bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool found = false; - TF_CHECK_OK(ForEachElement([&found, &buffer]( - const ShapeIndex& /*index*/, bool /*is_leaf*/, - const std::vector<const LogicalBuffer*>& pointed_to_buffers) { + ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, + const std::vector<const LogicalBuffer*>& + pointed_to_buffers) { if (!found && std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), &buffer) != pointed_to_buffers.end()) { found = true; } - return Status::OK(); - })); + }); return found; } @@ -129,34 +125,32 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>> -TuplePointsToAnalysis::Run(const HloModule* module) { +TuplePointsToAnalysis::Run(const HloModule* module, Colorer colorer) { std::unique_ptr<TuplePointsToAnalysis> analysis( - new TuplePointsToAnalysis(module)); + new TuplePointsToAnalysis(module, std::move(colorer))); TF_RETURN_IF_ERROR(analysis->Analyze()); return std::move(analysis); } +/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>> +TuplePointsToAnalysis::Run(const HloModule* module) { + return Run(module, DefaultColorer()); +} + Status TuplePointsToAnalysis::Analyze() { points_to_.clear(); for (auto& computation : module_->computations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(computation->instructions())); + // Run points-to analysis on fusion instructions in 'computation'.
for (auto& instruction : computation->instructions()) { - TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( - instruction.get(), &instruction_defined_buffers_[instruction.get()])); - - const PointsToSet& points_to_set = GetPointsToSet(instruction.get()); - TF_RETURN_IF_ERROR(points_to_set.ForEachElement([this, &instruction]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& pointed_to_buffers) { - for (const LogicalBuffer* buffer : pointed_to_buffers) { - if (buffer_aliases_.count(buffer) == 0) { - buffer_aliases_.insert({buffer, std::vector()}); - } - buffer_aliases_[buffer].emplace_back(*buffer, instruction.get(), - index); - } - return Status::OK(); - })); + if (instruction->opcode() != HloOpcode::kFusion) { + continue; + } + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } @@ -165,11 +159,33 @@ Status TuplePointsToAnalysis::Analyze() { return Status::OK(); } +Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( + const std::list>& instructions) { + for (auto& instruction : instructions) { + TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( + instruction.get(), &instruction_defined_buffers_[instruction.get()])); + + const PointsToSet& points_to_set = GetPointsToSet(instruction.get()); + points_to_set.ForEachElement( + [this, &instruction]( + const ShapeIndex& index, + const std::vector& pointed_to_buffers) { + for (const LogicalBuffer* buffer : pointed_to_buffers) { + if (buffer_aliases_.count(buffer) == 0) { + buffer_aliases_.insert({buffer, std::vector()}); + } + buffer_aliases_[buffer].emplace_back(instruction.get(), index); + } + }); + } + return Status::OK(); +} + const LogicalBuffer& TuplePointsToAnalysis::NewLogicalBuffer( HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); - logical_buffers_.push_back( - MakeUnique(instruction, index, next_buffer_id_)); + logical_buffers_.push_back(MakeUnique( + instruction, index, next_buffer_id_, colorer_(instruction, index))); ++next_buffer_id_; return *logical_buffers_.back(); } @@ -179,13 +195,12 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { // contains a single element LogicalBuffer(hlo_instruction, i). This indicates // that this instruction is the source of all buffers in its own output. PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction); - TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement( - [this, hlo_instruction](const ShapeIndex& index, bool /*is_leaf*/, + points_to_set.ForEachMutableElement( + [this, hlo_instruction](const ShapeIndex& index, std::vector* buffers) { const LogicalBuffer& buffer = NewLogicalBuffer(hlo_instruction, index); buffers->push_back(&buffer); - return Status::OK(); - })); + }); if (ShapeUtil::IsTuple(hlo_instruction->shape())) { // If the hlo instruction is a tuple-shaped, then trivially the instruction @@ -207,24 +222,23 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. - TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement([&, this]( - const ShapeIndex& target_index, bool /*is_leaf*/, - std::vector* points_to) { - // Construct an index into the operand by prepending element_index to the - // index for the GetTupleElement instruction's points-to set. 
- ShapeIndex src_index; - src_index.push_back(element_index); - for (auto element : target_index) { - src_index.push_back(element); - } + points_to_set.ForEachMutableElement( + [&, this](const ShapeIndex& target_index, + std::vector* points_to) { + // Construct an index into the operand by prepending element_index to + // the index for the GetTupleElement instruction's points-to set. + ShapeIndex src_index; + src_index.push_back(element_index); + for (auto element : target_index) { + src_index.push_back(element); + } - *points_to = operand_points_to_set.element(src_index); - for (HloInstruction* tuple : - operand_points_to_set.tuple_sources(src_index)) { - points_to_set.add_tuple_source(target_index, tuple); - } - return Status::OK(); - })); + *points_to = operand_points_to_set.element(src_index); + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + }); return Status::OK(); } @@ -265,9 +279,9 @@ Status TuplePointsToAnalysis::HandleTuple( // Copy the points-to set (and tuple sources) of the operand into the // respective subtree of the tuple instructions points-to set. - TF_RETURN_IF_ERROR(operand_points_to_set.ForEachElement( + operand_points_to_set.ForEachElement( [&points_to_set, &operand_points_to_set, i]( - const ShapeIndex& src_index, bool /*is_leaf*/, + const ShapeIndex& src_index, const std::vector& points_to) { ShapeIndex target_index; target_index.push_back(i); @@ -281,8 +295,7 @@ Status TuplePointsToAnalysis::HandleTuple( operand_points_to_set.tuple_sources(src_index)) { points_to_set.add_tuple_source(target_index, tuple); } - return Status::OK(); - })); + }); } points_to_set.add_tuple_source({}, tuple); @@ -303,9 +316,8 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, // add in elements of the on_false points-to set (tuple sources). PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); const PointsToSet& false_points_to_set = *FindOrDie(points_to_, on_false); - TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement( - [&](const ShapeIndex& index, bool /*is_leaf*/, - std::vector* buffers) { + points_to_set.ForEachMutableElement( + [&](const ShapeIndex& index, std::vector* buffers) { for (const LogicalBuffer* false_buffer : false_points_to_set.element(index)) { points_to_set.AddPointedToBuffer(*false_buffer, index); @@ -314,8 +326,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) { points_to_set.add_tuple_source(index, tuple); } - return Status::OK(); - })); + }); // Select creates a new (top-level) buffer to store its result, so its // respective element in the points-to set should contain only itself. @@ -325,12 +336,6 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, return Status::OK(); } -Status TuplePointsToAnalysis::HandleFusion(HloInstruction* fusion) { - return ShapeUtil::IsTuple(fusion->shape()) - ? 
Unimplemented("HandleFusion with tuple output") - : DefaultAction(fusion); -} - const PointsToSet& TuplePointsToAnalysis::GetPointsToSet( const HloInstruction* hlo_instruction) const { return *FindOrDie(points_to_, hlo_instruction); @@ -344,7 +349,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( } bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( - HloInstruction* instruction, const ShapeIndex& index) const { + const HloInstruction* instruction, const ShapeIndex& index) const { const std::vector& buffers = GetPointsToSet(instruction).element(index); return (buffers.size() == 1 && buffers[0]->instruction() == instruction); @@ -407,28 +412,29 @@ TuplePointsToAnalysis::GetBuffersDefinedByInstruction( Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( const HloInstruction* instruction, std::vector* buffers) { - return GetPointsToSet(instruction) - .ForEachElement([this, buffers, instruction]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& source_buffers) { - // Add buffers which 'instruction' is the source of. - CHECK(!source_buffers.empty()); - if (source_buffers.size() == 1 && - source_buffers[0]->instruction() == instruction) { - // If this instruction is the source of this buffer the - // indices must match. - DCHECK(source_buffers[0]->index() == index); - buffers->push_back(source_buffers[0]); - } else { - // If the points-to set includes more than one buffer then - // necessarily this instruction did not produce the - // buffer. - for (const LogicalBuffer* source_buffer : source_buffers) { - DCHECK(source_buffer->instruction() != instruction); - } - } - return Status::OK(); - }); + GetPointsToSet(instruction) + .ForEachElement( + [this, buffers, instruction]( + const ShapeIndex& index, + const std::vector& source_buffers) { + // Add buffers which 'instruction' is the source of. + CHECK(!source_buffers.empty()); + if (source_buffers.size() == 1 && + source_buffers[0]->instruction() == instruction) { + // If this instruction is the source of this buffer the + // indices must match. + DCHECK(source_buffers[0]->index() == index); + buffers->push_back(source_buffers[0]); + } else { + // If the points-to set includes more than one buffer then + // necessarily this instruction did not produce the + // buffer. + for (const LogicalBuffer* source_buffer : source_buffers) { + DCHECK(source_buffer->instruction() != instruction); + } + } + }); + return Status::OK(); } PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( @@ -437,59 +443,67 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( // from src PointsToSet. 
PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); const PointsToSet& src_points_to_set = GetPointsToSet(src); - TF_CHECK_OK(dst_points_to_set.ForEachMutableElement( + dst_points_to_set.ForEachMutableElement( [this, &dst_points_to_set, &src_points_to_set]( - const ShapeIndex& index, bool /*is_leaf*/, - std::vector* buffers) { + const ShapeIndex& index, std::vector* buffers) { *buffers = src_points_to_set.element(index); for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { dst_points_to_set.add_tuple_source(index, tuple_source); } - return Status::OK(); - })); + }); return *FindOrDie(points_to_, instruction); } string TuplePointsToAnalysis::ToString() const { string output = tensorflow::strings::Printf( "TuplePointsToSet for module %s:\n", module_->name().c_str()); - for (auto& computation : module_->computations()) { - tensorflow::strings::StrAppend(&output, "computation ", - computation->name().c_str(), ":\n"); + for (const auto& computation : module_->computations()) { + const char* entry = + computation.get() == module_->entry_computation() ? "entry " : ""; + tensorflow::strings::StrAppend(&output, entry, "computation ", + computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { - tensorflow::strings::StrAppend(&output, " instruction ", - instruction->ToShortString(), ":\n"); - const PointsToSet& points_to_set = GetPointsToSet(instruction); - TF_CHECK_OK(points_to_set.ForEachElement( - [&output](const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& points_to) { - tensorflow::strings::StrAppend( - &output, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); - return Status::OK(); - })); - } - for (auto& buffer : logical_buffers_) { - tensorflow::strings::StrAppend(&output, " buffer ", buffer->ToString(), - ":\n"); - for (const BufferAlias& buffer_alias : buffer_aliases_.at(buffer.get())) { - tensorflow::strings::StrAppend(&output, " alias ", - buffer_alias.ToString(), "\n"); + InstructionToString(instruction, &output); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& fused : instruction->fused_instructions()) { + InstructionToString(fused.get(), &output); + } } } } tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); - for (const auto& buffer : logical_buffers_) { - tensorflow::strings::StrAppend(&output, " ", buffer->ToString()); + for (auto& buffer : logical_buffers_) { + tensorflow::strings::StrAppend(&output, " buffer ", buffer->ToString(), + ":\n"); + for (const BufferAlias& buffer_alias : buffer_aliases_.at(buffer.get())) { + tensorflow::strings::StrAppend(&output, " alias ", + buffer_alias.ToString(), "\n"); + } } return output; } +void TuplePointsToAnalysis::InstructionToString( + const HloInstruction* instruction, string* output) const { + const string prefix = instruction->IsFused() ? 
" " : ""; + tensorflow::strings::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); + const PointsToSet& points_to_set = GetPointsToSet(instruction); + points_to_set.ForEachElement([&prefix, &output]( + const ShapeIndex& index, + const std::vector& + points_to) { + tensorflow::strings::StrAppend( + output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", + tensorflow::str_util::Join( + points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 7a3eb772d6b..4d7fc7cbc9e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -65,7 +66,7 @@ class PointsToSet : public ShapeTree> { // Creates a set containing the union of all LogicalBuffers contained in the // PointsToSet. - std::set CreateFlattenedSet() const; + tensorflow::gtl::FlatSet CreateFlattenedSet() const; // Returns true if the given buffer is in the points-to set at the given // index. @@ -116,27 +117,21 @@ class PointsToSet : public ShapeTree> { // value. class BufferAlias { public: - BufferAlias(const LogicalBuffer& buffer, HloInstruction* instruction, - const ShapeIndex& index) - : buffer_(&buffer), instruction_(instruction), index_(index) {} - - // Return the logical buffer aliased at the instruction and index. - const LogicalBuffer& buffer() const { return *buffer_; } + BufferAlias(HloInstruction* instruction, const ShapeIndex& index) + : instruction_(instruction), index_(index) {} // Return the instruction/index of the subshape. HloInstruction* instruction() const { return instruction_; } const ShapeIndex& index() const { return index_; } bool operator==(const BufferAlias& other) const { - return buffer_ == other.buffer_ && instruction_ == other.instruction_ && - index_ == other.index_; + return instruction_ == other.instruction_ && index_ == other.index_; } bool operator!=(const BufferAlias& other) const { return !(*this == other); } string ToString() const; private: - const LogicalBuffer* buffer_; HloInstruction* instruction_; const ShapeIndex index_; }; @@ -147,6 +142,15 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); // the potential sources of each buffer in each instruction's output. class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: + using Colorer = std::function; + + // Runs points-to analysis on 'module' with the provided buffer color + // assigner. + static StatusOr> Run( + const HloModule* module, Colorer colorer); + + // Runs points-to analysis on 'module' with the default color assigner. static StatusOr> Run( const HloModule* module); @@ -185,7 +189,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const HloInstruction* instruction) const; // Returns true if the given instruction defines a buffer at the given index. 
- bool InstructionDefinesBufferAtIndex(HloInstruction* instruction, + bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction, const ShapeIndex& index) const; // Returns an OK status if the given buffer is defined by instruction @@ -205,20 +209,34 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { HloInstruction* operand) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; - Status HandleFusion(HloInstruction* fusion) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) override; string ToString() const; + static Colorer DefaultColorer() { + return [](const HloInstruction* instruction, const ShapeIndex& index) { + return LogicalBuffer::Color(0); + }; + } + private: - explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {} + explicit TuplePointsToAnalysis(const HloModule* module, + Colorer colorer = DefaultColorer()) + : module_(module), colorer_(colorer) {} // Perform the analysis. Should be called immediately after constructing the // object and before calling GetPointsToSet. Status Analyze(); + // Populates instruction-defined buffers and aliases for each instruction + // in 'instructions'. The parameter 'instructions' is passed in a form + // common to how both HloComputation and fusion instructions maintain a + // list of instructions. + Status PopulateDefinedBuffersAndAliases( + const std::list<std::unique_ptr<HloInstruction>>& instructions); + // Create a new logical buffer and return a reference to it. The newly created // buffer is stored in an internal vector of LogicalBuffers and can be // accessed with GetBuffer. @@ -239,6 +257,10 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const HloInstruction* instruction, std::vector<const LogicalBuffer*>* buffers); + // Print points-to set for 'instruction' to 'output'. + void InstructionToString(const HloInstruction* instruction, + string* output) const; + // The module this analysis is performed on. const HloModule* module_; @@ -247,10 +269,11 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { points_to_; // A map containing the LogicalBuffers defined by each HLO instruction. - std::unordered_map<const HloInstruction*, std::vector<const LogicalBuffer*>> + tensorflow::gtl::FlatMap<const HloInstruction*, std::vector<const LogicalBuffer*>> instruction_defined_buffers_; - std::unordered_map<const LogicalBuffer*, std::vector<BufferAlias>> + tensorflow::gtl::FlatMap<const LogicalBuffer*, std::vector<BufferAlias>> buffer_aliases_; // All logical buffers in the module, indexed by LogicalBuffer::Id. Keep as @@ -260,6 +283,9 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // The ID of the next logical buffer created. LogicalBuffer::Id next_buffer_id_ = 0; + // Used to color the created logical buffers. + Colorer colorer_; + TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis); }; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e4dd4d309e5..9909c11929d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -19,24 +19,41 @@ limitations under the License.
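The `Colorer` hook declared in the header above lets callers partition logical buffers into color classes while the analysis runs. Below is a short sketch of the new two-argument `Run()` overload; it assumes only what `DefaultColorer()` shows, namely that `LogicalBuffer::Color` is constructible from an integer, and the particular coloring (top-level vs. nested buffers) is illustrative only.

```c++
// Sketch: give buffers defined at the top-level shape index color 0 and
// buffers defined inside tuples color 1. 'module' is a const HloModule* that
// is assumed to be in scope.
std::unique_ptr<xla::TuplePointsToAnalysis> analysis =
    xla::TuplePointsToAnalysis::Run(
        module,
        /*colorer=*/[](const xla::HloInstruction* instruction,
                       const xla::ShapeIndex& index) {
          return xla::LogicalBuffer::Color(index.empty() ? 0 : 1);
        })
        .ConsumeValueOrDie();
```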
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::UnorderedElementsAreArray; +using ::testing::UnorderedElementsAre; + class TuplePointsToAnalysisTest : public HloTestBase { protected: // Builds a module with the given entry computation and runs points to // analysis. void BuildModuleAndRunAnalysis(std::unique_ptr computation) { - module_.reset(new HloModule(TestName())); + BuildModule(std::move(computation)); + RunAnalysis(); + } + + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); } @@ -59,7 +76,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { const std::vector& points_to_set, tensorflow::gtl::ArraySlice buffers) { std::vector vec(buffers.begin(), buffers.end()); - EXPECT_MATCH(points_to_set, testing::UnorderedElementsAre(vec)); + EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } // Checks that the given points-to set contains exactly (unordered) the @@ -76,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Overload which takes a std::set instead of a std::vector. void ExpectHasTopLevelBuffers( - const std::set& points_to_set, + const tensorflow::gtl::FlatSet& points_to_set, tensorflow::gtl::ArraySlice instructions) { ExpectHasTopLevelBuffers(std::vector( points_to_set.begin(), points_to_set.end()), @@ -94,22 +111,16 @@ class TuplePointsToAnalysisTest : public HloTestBase { .ValueOrDie(); std::vector expected_aliases; for (auto& pair : expected) { - expected_aliases.push_back(BufferAlias(*buffer, pair.first, pair.second)); + expected_aliases.push_back(BufferAlias(pair.first, pair.second)); } - EXPECT_MATCH(points_to_analysis_->GetBufferAliases(*buffer), - testing::UnorderedElementsAre(expected_aliases)); + EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer), + UnorderedElementsAreArray(expected_aliases)); } std::unique_ptr module_; std::unique_ptr points_to_analysis_; }; -// Expect the given std::set as A contains exactly the given -// HloInstruction*s as __VA_ARGS__. -#define EXPECT_ISET(A, ...) 
\ - EXPECT_MATCH(testing::SetToVec(A), \ - testing::UnorderedMatcher(__VA_ARGS__)) - TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( @@ -135,8 +146,8 @@ TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), @@ -194,9 +205,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(inner_tuple).element({}), {inner_tuple}); - EXPECT_ISET( + EXPECT_THAT( points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}), - inner_tuple); + UnorderedElementsAre(inner_tuple)); EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); @@ -204,10 +215,10 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), {constant1, constant2, constant3, inner_tuple, tuple}); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), - inner_tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), + UnorderedElementsAre(inner_tuple)); EXPECT_TRUE( points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty()); @@ -251,7 +262,8 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { {constant1, constant2, inner_tuple}); ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple}); - EXPECT_ISET(points_to_set.tuple_sources({}), inner_tuple); + EXPECT_THAT(points_to_set.tuple_sources({}), + UnorderedElementsAre(inner_tuple)); } TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { @@ -449,8 +461,10 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2}); // Verify tuple sources. 
- EXPECT_ISET(points_to_set.tuple_sources({}), tuple1, tuple2); - EXPECT_ISET(points_to_set.tuple_sources({0}), inner_tuple1, inner_tuple2); + EXPECT_THAT(points_to_set.tuple_sources({}), + UnorderedElementsAre(tuple1, tuple2)); + EXPECT_THAT(points_to_set.tuple_sources({0}), + UnorderedElementsAre(inner_tuple1, inner_tuple2)); EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size()); EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size()); } @@ -478,8 +492,8 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), @@ -540,5 +554,217 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}}); } +class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { + protected: + // Builds a computation, runs instruction fusion HloPass, runs points-to + // analysis, then checks for expected results (see unit test cases for + // example computation graphs). + void Run(const bool add_additional_gte0_user) { + Shape input_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {3}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape tuple_shape = + ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape}); + + auto builder = HloComputation::Builder(TestName()); + // Create tuple-shaped parameter. + auto tuple_param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param0")); + // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1). + auto tuple_element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); + auto ones = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) + auto update = builder.AddInstruction(HloInstruction::CreateBinary( + update_shape, HloOpcode::kAdd, tuple_element1, ones)); + // Create 'input' = GetTupleElement(tuple_param0, 0). + auto input = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0)); + + if (add_additional_gte0_user) { + // Create 'slice' as an additional user of 'input'. + auto slice = builder.AddInstruction( + HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1})); + // Modify 'update' to take 'slice' output. + update = builder.AddInstruction(HloInstruction::CreateBinary( + update_shape, HloOpcode::kAdd, update, slice)); + } + + // Create slice 'starts' = GetTupleElement(tuple_param0, 2). + auto starts = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2)); + // Update 'input' with 'update' at dynamic 'starts' indices. + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + input_shape, input, update, starts)); + + // Build computation and add it to module as entry computation. + BuildModule(builder.Build()); + // Run instruction fusion HloPass. + EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive) + .Run(module_.get()) + .ValueOrDie()); + // Get computation root instruction (should be a kFusion). 
+ auto* fusion = module_->entry_computation()->root_instruction(); + EXPECT_THAT(fusion, op::Fusion(tuple_param0)); + // Run points-to analysis (should include fused instructions from 'fusion'). + RunAnalysis(); + + // Check points-to set of fusion parameter associated with 'tuple_param0'. + auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fusion_param).element({}), + {GetBuffer(fusion_param, {})}); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fusion_param).element({0}), + {GetBuffer(fusion_param, {0})}); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fusion_param).element({1}), + {GetBuffer(fusion_param, {1})}); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fusion_param).element({2}), + {GetBuffer(fusion_param, {2})}); + + // Check that Gte at tuple_index = 0 points-to fusion_param({0}) + auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fused_gte0).element({}), + {GetBuffer(fusion_param, {0})}); + // Check that Gte at tuple_index = 1 points-to fusion_param({1}) + auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fused_gte1).element({}), + {GetBuffer(fusion_param, {1})}); + // Check that Gte at tuple_index = 2 points-to fusion_param({2}) + auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2); + ExpectHasBuffers( + points_to_analysis_->GetPointsToSet(fused_gte2).element({}), + {GetBuffer(fusion_param, {2})}); + + // Check buffer aliases of 'fusion_param' at shape index {0}. + ExpectHasBufferAliases(fusion_param, /*index=*/{0}, + {{fusion_param, {0}}, {fused_gte0, {}}}); + // Check buffer aliases of 'fusion_param' at shape index {1}. + ExpectHasBufferAliases(fusion_param, /*index=*/{1}, + {{fusion_param, {1}}, {fused_gte1, {}}}); + // Check buffer aliases of 'fusion_param' at shape index {2}. + ExpectHasBufferAliases(fusion_param, /*index=*/{2}, + {{fusion_param, {2}}, {fused_gte2, {}}}); + + // Check number of users of 'fusion_param' aliases at shape index {0}. + ExpectNumUsersOfAliases(fusion_param, {0}, + add_additional_gte0_user ? 2 : 1); + } + + // Returns fusion parameter (from 'fusion.fused_instructions') corresponding + // to fusion 'operand'. + HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, + HloInstruction* operand) { + auto it = std::find_if( + fusion->fused_instructions().begin(), + fusion->fused_instructions().end(), + [=](const std::unique_ptr<HloInstruction>& fused) { + return fused->opcode() == HloOpcode::kParameter && + fusion->operand(fused->parameter_number()) == operand; + }); + CHECK(it != fusion->fused_instructions().end()); + return (*it).get(); + } + + // Returns all users of 'fusion_param' at 'tuple_index'. + std::vector<HloInstruction*> GetFusionParameterUsersAt( + HloInstruction* fusion_param, int64 tuple_index) { + CHECK(ShapeUtil::IsTuple(fusion_param->shape())); + std::vector<HloInstruction*> users_at_tuple_index; + for (auto user : fusion_param->users()) { + CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode()); + if (user->tuple_index() == tuple_index) { + users_at_tuple_index.push_back(user); + } + } + return users_at_tuple_index; + } + + // Returns the unique user of 'fusion_param' at 'tuple_index'.
+ HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param, + int64 tuple_index) { + std::vector users = + GetFusionParameterUsersAt(fusion_param, tuple_index); + CHECK_EQ(1, users.size()); + return users[0]; + } + + // Checks that the count of all users of all aliases of 'instruction' at + // 'index' match 'expected_num_users'. + void ExpectNumUsersOfAliases(const HloInstruction* instruction, + const ShapeIndex& index, + const int64 expected_num_users) { + const auto* buffer = GetBuffer(instruction, index); + int64 num_users = 0; + for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) { + for (auto user : alias.instruction()->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // Gte instructions only access the top-level buffer of their operand. + continue; + } + ++num_users; + } + } + EXPECT_EQ(expected_num_users, num_users); + } +}; + +// Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users. +// Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices. +// Tests that there is a single user of the aliases of tuple-shaped fusion +// parameter 0 at shape index {0}. +// +// Param0 Const +// \ / +// Fusion +// / \ +// FusionParam0 FusionParam1 +// / | \ | +// Gte(0) Gte(2) Gte(1) / +// \ | \ / +// \ | Add +// \ | / +// \0 |2 /1 +// DynamicUpdateSlice // fused root. +// +TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) { + Run(/*add_additional_gte0_user=*/false); +} + +// Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users. +// Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices. +// Tests that there are two users of the aliases of tuple-shaped fusion +// parameter 0 at shape index {0}. +// +// Param0 Const +// \ / +// Fusion +// / \ +// FusionParam0 FusionParam1 +// / | \ | +// Gte(2) Gte(0) Gte(1) / +// \ | \ / +// \ |\ Add +// \ | \ / +// | | Slice / +// | | \ / +// | | Add +// | | | +// |2 |0 |1 +// DynamicUpdateSlice // fused root. +// +TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { + Run(/*add_additional_gte0_user=*/true); +} + } // namespace } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index db0c1f0369a..4aba8875161 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -17,9 +17,12 @@ limitations under the License. 
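The fixture's `ExpectNumUsersOfAliases` helper above demonstrates the main payoff of running points-to analysis through fusion instructions: a buffer's aliases now include (instruction, index) pairs that live inside fused computations. A condensed sketch of that traversal follows; `GetBufferDefinedAt` is assumed from the analysis API (the fixture's `GetBuffer` helper wraps a similar lookup), and error handling is elided.

```c++
namespace xla {

// Sketch: count the users of every alias of the buffer that 'param' defines
// at 'index', including aliases inside fused computations.
int64 CountAliasUsers(const TuplePointsToAnalysis& analysis,
                      const HloInstruction* param, const ShapeIndex& index) {
  const LogicalBuffer* buffer =
      analysis.GetBufferDefinedAt(param, index).ValueOrDie();
  int64 num_users = 0;
  for (const BufferAlias& alias : analysis.GetBufferAliases(*buffer)) {
    // A BufferAlias is now just an (instruction, index) pair; it no longer
    // stores the aliased LogicalBuffer itself.
    num_users += alias.instruction()->users().size();
  }
  return num_users;
}

}  // namespace xla
```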
#include #include +#include +#include #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -50,6 +53,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kExp; case UNOP_FLOOR: return HloOpcode::kFloor; + case UNOP_IS_FINITE: + return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; case UNOP_LOGICAL_NOT: @@ -164,6 +169,9 @@ UserComputation::UserComputation(const string& name, : name_(name), next_handle_value_(1) { *session_computation_.mutable_computation_handle() = handle; session_computation_.set_name(name); + + VLOG(1) << "New UserComputation \"" << name + << "\", handle: " << handle.handle(); } ComputationDataHandle UserComputation::CreateComputationDataHandle() { @@ -198,15 +206,30 @@ StatusOr UserComputation::AddParameterInstruction( parameters_[parameter_number] = &request; + VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << parameter_request.ShortDebugString(); return handle; } Status UserComputation::AddSendInstruction(const SendRequest& send_request) { tensorflow::mutex_lock lock(mutex_); - *session_computation_.add_send_requests() = send_request; // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookupRequest(send_request.operand()).status()); + TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); + + // No handle is returned, but a handle must be assigned to this instruction + // for computation versioning. + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = ShapeUtil::MakeNil(); + *request.mutable_request()->mutable_send_request() = send_request; + + VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << send_request.ShortDebugString(); return Status::OK(); } @@ -223,6 +246,9 @@ StatusOr UserComputation::AddRecvInstruction( *request.mutable_output_shape() = shape; *request.mutable_request()->mutable_recv_request() = recv_request; + VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << recv_request.ShortDebugString(); return handle; } @@ -231,10 +257,10 @@ StatusOr UserComputation::AddPadInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(pad_request.operand())); + LookUpRequest(pad_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookupRequest(pad_request.padding_value())); + LookUpRequest(pad_request.padding_value())); TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( operand->output_shape(), @@ -248,6 +274,9 @@ StatusOr UserComputation::AddPadInstruction( *request.mutable_output_shape() = inferred_shape; *request.mutable_request()->mutable_pad_request() = pad_request; + VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << pad_request.ShortDebugString(); return handle; } @@ -267,6 +296,8 @@ StatusOr UserComputation::AddConstantInstruction( *request.mutable_output_shape() = 
validated_shape; *request.mutable_request()->mutable_constant_request() = constant_request; + VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle(); return handle; } @@ -275,7 +306,7 @@ StatusOr UserComputation::AddGetTupleElementInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(get_tuple_element_request.operand())); + LookUpRequest(get_tuple_element_request.operand())); Shape element_shape = ShapeUtil::GetTupleElementShape( operand->output_shape(), get_tuple_element_request.index()); @@ -288,6 +319,9 @@ StatusOr UserComputation::AddGetTupleElementInstruction( *request.mutable_request()->mutable_get_tuple_element_request() = get_tuple_element_request; + VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << get_tuple_element_request.ShortDebugString(); return handle; } @@ -295,10 +329,18 @@ Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { tensorflow::mutex_lock lock(mutex_); // Verify that the operand index is valid. - TF_RETURN_IF_ERROR(LookupRequest(trace_request.operand()).status()); + TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - *session_computation_.add_trace_requests() = trace_request; + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = ShapeUtil::MakeNil(); + *request.mutable_request()->mutable_trace_request() = trace_request; + VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << trace_request.ShortDebugString(); return Status::OK(); } @@ -331,7 +373,7 @@ StatusOr UserComputation::AddRngInstruction( // Verify that the parameter indices are valid; for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookupRequest(param).status()); + TF_RETURN_IF_ERROR(LookUpRequest(param).status()); } const Shape& validated_shape = rng_request.shape(); TF_RETURN_IF_ERROR( @@ -345,6 +387,9 @@ StatusOr UserComputation::AddRngInstruction( *request.mutable_output_shape() = validated_shape; *request.mutable_request()->mutable_rng_request() = rng_request; + VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << rng_request.ShortDebugString(); return handle; } @@ -355,7 +400,7 @@ StatusOr UserComputation::AddMapInstruction( std::vector operand_shapes; for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); operand_shapes.push_back(&operand->output_shape()); } @@ -377,6 +422,9 @@ StatusOr UserComputation::AddMapInstruction( request.add_embedded_computation_versions(to_apply_version); *request.mutable_request()->mutable_map_request() = map_request; + VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << map_request.ShortDebugString(); return handle; } @@ -386,9 +434,9 @@ StatusOr UserComputation::AddReduceInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(reduce_request.operand())); + 
LookUpRequest(reduce_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookupRequest(reduce_request.init_value())); + LookUpRequest(reduce_request.init_value())); VersionedComputationHandle::Version to_apply_version = to_apply_computation.version(); @@ -411,6 +459,9 @@ StatusOr UserComputation::AddReduceInstruction( request.add_embedded_computation_versions(to_apply_version); *request.mutable_request()->mutable_reduce_request() = reduce_request; + VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reduce_request.ShortDebugString(); return handle; } @@ -420,9 +471,9 @@ StatusOr UserComputation::AddReduceWindowInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(reduce_window_request.operand())); + LookUpRequest(reduce_window_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookupRequest(reduce_window_request.init_value())); + LookUpRequest(reduce_window_request.init_value())); VersionedComputationHandle::Version to_apply_version = to_apply_computation.version(); @@ -446,6 +497,9 @@ StatusOr UserComputation::AddReduceWindowInstruction( *request.mutable_request()->mutable_reduce_window_request() = reduce_window_request; + VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reduce_window_request.ShortDebugString(); return handle; } @@ -456,11 +510,11 @@ StatusOr UserComputation::AddSelectAndScatterInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(select_and_scatter_request.operand())); + LookUpRequest(select_and_scatter_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookupRequest(select_and_scatter_request.source())); + LookUpRequest(select_and_scatter_request.source())); TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookupRequest(select_and_scatter_request.init_value())); + LookUpRequest(select_and_scatter_request.init_value())); VersionedComputationHandle::Version select_version = select_computation.version(); @@ -489,6 +543,9 @@ StatusOr UserComputation::AddSelectAndScatterInstruction( *request.mutable_request()->mutable_select_and_scatter_request() = select_and_scatter_request; + VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << select_and_scatter_request.ShortDebugString(); return handle; } @@ -497,7 +554,7 @@ StatusOr UserComputation::AddReverseInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(reverse_request.operand())); + LookUpRequest(reverse_request.operand())); TF_ASSIGN_OR_RETURN( Shape inferred_shape, ShapeInference::InferReverseShape( @@ -509,6 +566,9 @@ StatusOr UserComputation::AddReverseInstruction( *request.mutable_output_handle() = handle; *request.mutable_output_shape() = inferred_shape; *request.mutable_request()->mutable_reverse_request() = reverse_request; + VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reverse_request.ShortDebugString(); return handle; } @@ -519,7 +579,7 @@ StatusOr UserComputation::AddWhileInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookupRequest(while_request.init())); + 
LookUpRequest(while_request.init())); VersionedComputationHandle::Version condition_version = condition_computation.version(); @@ -546,6 +606,9 @@ StatusOr UserComputation::AddWhileInstruction( request.add_embedded_computation_versions(body_version); *request.mutable_request()->mutable_while_request() = while_request; + VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << while_request.ShortDebugString(); return handle; } @@ -555,7 +618,7 @@ StatusOr UserComputation::AddBroadcastInstruction( // Fetches and validates the operand. TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(broadcast_request.operand())); + LookUpRequest(broadcast_request.operand())); TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferBroadcastShape( operand->output_shape(), @@ -567,6 +630,10 @@ StatusOr UserComputation::AddBroadcastInstruction( *request.mutable_output_handle() = handle; *request.mutable_output_shape() = inferred_shape; *request.mutable_request()->mutable_broadcast_request() = broadcast_request; + + VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << broadcast_request.ShortDebugString(); return handle; } @@ -576,7 +643,7 @@ StatusOr UserComputation::AddReshapeInstruction( // Fetches and validates the operand. TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(reshape_request.operand())); + LookUpRequest(reshape_request.operand())); TF_ASSIGN_OR_RETURN( Shape inferred_shape, @@ -592,6 +659,36 @@ StatusOr UserComputation::AddReshapeInstruction( *request.mutable_output_shape() = inferred_shape; *request.mutable_request()->mutable_reshape_request() = reshape_request; + VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reshape_request.ShortDebugString(); + return handle; +} + +StatusOr UserComputation::AddTransposeInstruction( + const TransposeRequest& transpose_request) { + tensorflow::mutex_lock lock(mutex_); + + // Fetches and validates the operand. 
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(transpose_request.operand())); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferTransposeShape( + operand->output_shape(), + AsInt64Slice(transpose_request.dimensions()))); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + *request.mutable_request()->mutable_transpose_request() = transpose_request; + + VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << transpose_request.ShortDebugString(); return handle; } @@ -600,13 +697,14 @@ StatusOr UserComputation::AddSliceInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(slice_request.operand())); + LookUpRequest(slice_request.operand())); TF_ASSIGN_OR_RETURN( Shape new_shape, ShapeInference::InferSliceShape( operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()))); + AsInt64Slice(slice_request.limit_indices()), + AsInt64Slice(slice_request.stride()))); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -616,6 +714,9 @@ StatusOr UserComputation::AddSliceInstruction( *request.mutable_output_shape() = new_shape; *request.mutable_request()->mutable_slice_request() = slice_request; + VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << slice_request.ShortDebugString(); return handle; } @@ -624,10 +725,10 @@ StatusOr UserComputation::AddDynamicSliceInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(dynamic_slice_request.operand())); + LookUpRequest(dynamic_slice_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookupRequest(dynamic_slice_request.start_indices())); + LookUpRequest(dynamic_slice_request.start_indices())); TF_ASSIGN_OR_RETURN( Shape new_shape, @@ -644,6 +745,9 @@ StatusOr UserComputation::AddDynamicSliceInstruction( *request.mutable_request()->mutable_dynamic_slice_request() = dynamic_slice_request; + VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dynamic_slice_request.ShortDebugString(); return handle; } @@ -653,14 +757,14 @@ UserComputation::AddDynamicUpdateSliceInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(dynamic_update_slice_request.operand())); + LookUpRequest(dynamic_update_slice_request.operand())); TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookupRequest(dynamic_update_slice_request.update())); + LookUpRequest(dynamic_update_slice_request.update())); TF_ASSIGN_OR_RETURN( const OperationRequest* start_indices, - LookupRequest(dynamic_update_slice_request.start_indices())); + LookUpRequest(dynamic_update_slice_request.start_indices())); TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferDynamicUpdateSliceShape( @@ -676,6 +780,10 @@ UserComputation::AddDynamicUpdateSliceInstruction( *request.mutable_request()->mutable_dynamic_update_slice_request() = dynamic_update_slice_request; + VLOG(1) << "AddDynamicUpdateSliceInstruction (" + << GetVersionedHandleInternal() << "), data handle " + << 
handle.handle() << ": " + << dynamic_update_slice_request.ShortDebugString(); return handle; } @@ -685,7 +793,7 @@ StatusOr UserComputation::AddConcatenateInstruction( std::vector operand_shapes; for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); operand_shapes.push_back(&operand->output_shape()); } @@ -702,6 +810,9 @@ StatusOr UserComputation::AddConcatenateInstruction( *request.mutable_request()->mutable_concatenate_request() = concatenate_request; + VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << concatenate_request.ShortDebugString(); return handle; } @@ -710,7 +821,7 @@ StatusOr UserComputation::AddConvertInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(convert_request.operand())); + LookUpRequest(convert_request.operand())); TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( operand->output_shape(), @@ -724,6 +835,9 @@ StatusOr UserComputation::AddConvertInstruction( *request.mutable_output_shape() = new_shape; *request.mutable_request()->mutable_convert_request() = convert_request; + VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << convert_request.ShortDebugString(); return handle; } @@ -732,9 +846,9 @@ StatusOr UserComputation::AddConvolveInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookupRequest(convolve_request.lhs())); + LookUpRequest(convolve_request.lhs())); TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookupRequest(convolve_request.rhs())); + LookUpRequest(convolve_request.rhs())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( lhs->output_shape(), rhs->output_shape(), convolve_request.window(), @@ -748,6 +862,9 @@ StatusOr UserComputation::AddConvolveInstruction( *request.mutable_output_shape() = shape; *request.mutable_request()->mutable_convolve_request() = convolve_request; + VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << convolve_request.ShortDebugString(); return handle; } @@ -756,7 +873,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(cross_replica_sum_request.operand())); + LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( operand->output_shape())); @@ -769,6 +886,9 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( *request.mutable_request()->mutable_cross_replica_sum_request() = cross_replica_sum_request; + VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << cross_replica_sum_request.ShortDebugString(); return handle; } @@ -792,6 +912,9 @@ StatusOr UserComputation::AddInfeedInstruction( *request.mutable_output_shape() = shape; *request.mutable_request()->mutable_infeed_request() = infeed_request; + VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << infeed_request.ShortDebugString(); return handle; } @@ -799,9 +922,29 @@ Status 
UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); - *session_computation_.add_outfeed_requests() = outfeed_request; + const Shape& shape = outfeed_request.shape(); + if (ShapeUtil::IsNestedTuple(shape)) { + return InvalidArgument("Outfeed does not support nested tuple shapes"); + } + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookupRequest(outfeed_request.operand()).status()); + TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); + + // No handle is returned, but a handle must be assigned to this instruction + // for computation versioning. + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_outfeed_request() = outfeed_request; + + VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << outfeed_request.ShortDebugString(); return Status::OK(); } @@ -812,7 +955,7 @@ StatusOr UserComputation::AddCallInstruction( std::vector operand_shapes; for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle)); + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); operand_shapes.push_back(&operand->output_shape()); } @@ -834,6 +977,9 @@ StatusOr UserComputation::AddCallInstruction( request.add_embedded_computation_versions(to_apply_version); *request.mutable_request()->mutable_call_request() = call_request; + VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << call_request.ShortDebugString(); return handle; } @@ -842,7 +988,7 @@ StatusOr UserComputation::AddCustomCallInstruction( tensorflow::mutex_lock lock(mutex_); for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookupRequest(handle).status()); + TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); } const ComputationDataHandle handle = CreateComputationDataHandle(); @@ -854,6 +1000,9 @@ StatusOr UserComputation::AddCustomCallInstruction( *request.mutable_request()->mutable_custom_call_request() = custom_call_request; + VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << custom_call_request.ShortDebugString(); return handle; } @@ -862,7 +1011,7 @@ StatusOr UserComputation::AddUnaryInstruction( tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookupRequest(unary_request.operand())); + LookUpRequest(unary_request.operand())); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), operand->output_shape())); @@ -875,6 +1024,9 @@ StatusOr UserComputation::AddUnaryInstruction( *request.mutable_output_shape() = shape; *request.mutable_request()->mutable_unary_op_request() = unary_request; + VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << unary_request.ShortDebugString(); return handle; } @@ -883,9 +1035,9 @@ StatusOr UserComputation::AddBinaryInstruction( tensorflow::mutex_lock lock(mutex_); 
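The AddOutfeedInstruction hunk above adds two structural checks before the request is recorded. A minimal sketch of that validation, assuming a toy Shape type in place of xla::Shape and ShapeUtil (the field names here are illustrative, not the real API):

```c++
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for xla::Shape (illustrative only).
struct Shape {
  bool is_tuple = false;
  bool has_layout = false;
  std::vector<Shape> tuple_shapes;
};

// Mirrors the ShapeUtil::IsNestedTuple condition: a tuple containing a tuple.
bool IsNestedTuple(const Shape& shape) {
  if (!shape.is_tuple) return false;
  for (const Shape& element : shape.tuple_shapes) {
    if (element.is_tuple) return true;
  }
  return false;
}

// Returns an error message, or "" if the shape is acceptable for outfeed.
std::string ValidateOutfeedShape(const Shape& shape) {
  if (IsNestedTuple(shape)) {
    return "Outfeed does not support nested tuple shapes";
  }
  if (!shape.has_layout) {
    return "Given shape to Outfeed must have a layout";
  }
  return "";
}

int main() {
  Shape scalar;  // not a tuple, but carries no layout either
  Shape tuple_of_tuples{true, true, {Shape{true, true, {}}}};
  std::cout << ValidateOutfeedShape(scalar) << "\n";           // layout error
  std::cout << ValidateOutfeedShape(tuple_of_tuples) << "\n";  // nesting error
}
```

The real method reports these conditions as InvalidArgument statuses and, because outfeed returns no value to the client, still allocates a data handle purely so the instruction participates in computation versioning.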
   TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
-                      LookupRequest(binary_request.lhs()));
+                      LookUpRequest(binary_request.lhs()));
   TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
-                      LookupRequest(binary_request.rhs()));
+                      LookUpRequest(binary_request.rhs()));
   TF_ASSIGN_OR_RETURN(
       Shape shape,
       ShapeInference::InferBinaryOpShape(
@@ -900,6 +1052,9 @@ StatusOr<ComputationDataHandle> UserComputation::AddBinaryInstruction(
   *request.mutable_output_shape() = shape;
   *request.mutable_request()->mutable_binary_op_request() = binary_request;
 
+  VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal()
+          << "), data handle " << handle.handle() << ": "
+          << binary_request.ShortDebugString();
   return handle;
 }
 
@@ -908,11 +1063,11 @@ StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
   tensorflow::mutex_lock lock(mutex_);
 
   TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
-                      LookupRequest(ternary_request.lhs()));
+                      LookUpRequest(ternary_request.lhs()));
   TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
-                      LookupRequest(ternary_request.rhs()));
+                      LookUpRequest(ternary_request.rhs()));
   TF_ASSIGN_OR_RETURN(const OperationRequest* ehs,
-                      LookupRequest(ternary_request.ehs()));
+                      LookUpRequest(ternary_request.ehs()));
   TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTernaryOpShape(
                                        ternary_request.triop(), lhs->output_shape(),
@@ -926,6 +1081,9 @@ StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
   *request.mutable_output_shape() = shape;
   *request.mutable_request()->mutable_ternary_op_request() = ternary_request;
 
+  VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal()
+          << "), data handle " << handle.handle() << ": "
+          << ternary_request.ShortDebugString();
   return handle;
 }
 
@@ -935,7 +1093,7 @@ StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
 
   std::vector<const Shape*> operand_shapes;
   for (const ComputationDataHandle& handle : variadic_request.operands()) {
-    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
     operand_shapes.push_back(&operand->output_shape());
   }
 
@@ -951,16 +1109,35 @@ StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
   *request.mutable_output_shape() = shape;
   *request.mutable_request()->mutable_variadic_op_request() = variadic_request;
 
+  VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal()
+          << "), data handle " << handle.handle() << ": "
+          << variadic_request.ShortDebugString();
   return handle;
 }
 
 StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
   tensorflow::mutex_lock lock(mutex_);
 
-  TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+  TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
   return operand->output_shape();
 }
 
+Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
+                                      const OpMetadata& metadata) {
+  tensorflow::mutex_lock lock(mutex_);
+
+  int64 handle_value = handle.handle();
+  if (session_computation_.requests().count(handle_value) == 0) {
+    return InvalidArgument("Invalid handle in SetOpMetadata (%lld)",
+                           handle_value);
+  }
+  *session_computation_.mutable_requests()
+       ->at(handle_value)
+       .mutable_request()
+       ->mutable_metadata() = metadata;
+  return Status::OK();
+}
+
 Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
   tensorflow::mutex_lock lock(mutex_);
 
@@ -970,12 +1147,18 @@ Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
 
   handle_to_return_ = handle;
 
+  VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to "
+          << GetVersionedHandleInternal();
+
   return Status::OK();
 }
 
 VersionedComputationHandle UserComputation::GetVersionedHandle() const {
   tensorflow::mutex_lock lock(mutex_);
+  return GetVersionedHandleInternal();
+}
+
+VersionedComputationHandle UserComputation::GetVersionedHandleInternal()
+    const {
   VersionedComputationHandle versioned_handle;
   versioned_handle.handle = session_computation_.computation_handle();
 
@@ -1008,12 +1191,62 @@ VersionedComputationHandle::Version UserComputation::version() const {
   return GetVersionedHandle().version;
 }
 
+namespace {
+
+// Returns true if the operation type corresponding to the given opcase can be
+// the root of the computation.
+bool CanBeRoot(const OpRequest::OpCase& op_case) {
+  switch (op_case) {
+    case OpRequest::kTraceRequest:
+    case OpRequest::kSendRequest:
+    case OpRequest::kOutfeedRequest:
+      return false;
+    default:
+      return true;
+  }
+}
+
+// Returns a pointer to the operation with the given data handle value in the
+// given SessionComputation.
+StatusOr<const OperationRequest*> LookUpRequest(
+    int64 handle_value, const SessionComputation& session_computation) {
+  if (session_computation.requests().count(handle_value) == 0) {
+    return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
+  }
+  return &session_computation.requests().at(handle_value);
+}
+
+// Returns the OperationRequest corresponding to the root (result) of the
+// session computation.
+StatusOr<const OperationRequest*> GetRoot(
+    VersionedComputationHandle::Version version,
+    const SessionComputation& session_computation) {
+  TF_RET_CHECK(version > 0);
+  // Not all instructions can be roots. Walk backwards from the operation
+  // indicated by this version until a valid root is found.
+  const OperationRequest* root_request = nullptr;
+  while (version > 0) {
+    TF_ASSIGN_OR_RETURN(root_request,
                        LookUpRequest(version, session_computation));
+    if (CanBeRoot(root_request->request().op_case())) {
+      break;
+    }
+    version--;
+  }
+  if (version == 0) {
+    return InternalError("Computation contains no root operation");
+  }
+  return root_request;
+}
+
+}  // namespace
+
 StatusOr<std::shared_ptr<const ProgramShape>>
 UserComputation::ComputeProgramShape(
     VersionedComputationHandle::Version version) const {
   tensorflow::mutex_lock lock(mutex_);
 
-  CHECK(version > 0 && version < next_handle_value_);
+  TF_RET_CHECK(version > 0 && version < next_handle_value_);
 
   if (program_shape_ == nullptr || program_shape_version_ != version) {
     // ProgramShape has not been computed yet, or is for different
@@ -1042,7 +1275,9 @@ UserComputation::ComputeProgramShape(
     }
 
     // The root determines the output shape.
-    *program_shape->mutable_result() = GetRoot(version).output_shape();
+    TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
+                        GetRoot(version, session_computation_));
+    *program_shape->mutable_result() = root_request->output_shape();
     if (ShapeUtil::IsOpaque(program_shape->result())) {
       return Unimplemented("Computation results cannot be opaque");
     }
@@ -1279,6 +1514,7 @@ void ConstantVisitor(const SessionComputation& session_computation,
                      is_constant);
       // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
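+      // A while loop is therefore conservatively marked non-constant even if
+      // its init operand is constant, since the condition and body
+      // computations are not analyzed (see the TODO above).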
+ *is_constant = false; break; } @@ -1294,6 +1530,14 @@ void ConstantVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kTransposeRequest: { + const TransposeRequest& transpose_request = + request.request().transpose_request(); + ConstantVisitor(session_computation, transpose_request.operand(), visited, + is_constant); + break; + } + case OpRequest::kVariadicOpRequest: { const VariadicOpRequest& variadic_op_request = request.request().variadic_op_request(); @@ -1338,7 +1582,7 @@ StatusOr UserComputation::IsConstant( tensorflow::mutex_lock lock(mutex_); // Verify that the handle is valid. - auto operation_status = LookupRequest(handle); + auto operation_status = LookUpRequest(handle); if (!operation_status.ok()) { return operation_status.status(); } @@ -1350,17 +1594,18 @@ StatusOr UserComputation::IsConstant( return is_constant; } -const OperationRequest& UserComputation::GetRoot( - VersionedComputationHandle::Version version) const { - CHECK(version > 0 && version < next_handle_value_); - return session_computation_.requests().at(version); -} - std::vector UserComputation::GetEmbeddedComputations( VersionedComputationHandle::Version version) const { tensorflow::mutex_lock lock(mutex_); + VLOG(1) + << "GetEmbeddedComputations(" << name() << " " + << VersionedComputationHandle{session_computation_.computation_handle(), + version} + << ")"; + XLA_VLOG_LINES(3, session_computation_.DebugString()); + std::vector computations; for (const auto& handle_request : session_computation_.requests()) { int64 handle_value = handle_request.first; @@ -1442,6 +1687,12 @@ UserComputation::GetEmbeddedComputations( } } } + VLOG(2) << "Embedded computations: " + << tensorflow::str_util::Join( + computations, ", ", + [](string* out, const VersionedComputationHandle& h) { + out->append(h.ToString()); + }); return computations; } @@ -1543,7 +1794,7 @@ SessionComputation UserComputation::CloneSessionComputation( return result; } -StatusOr UserComputation::LookupRequest( +StatusOr UserComputation::LookUpRequest( const ComputationDataHandle& handle) const { int64 handle_value = handle.handle(); if (session_computation_.requests().count(handle_value) == 0) { @@ -1594,15 +1845,15 @@ namespace { // DFS order lowering each OperationRequest to an HLO instruction. class ComputationLowerer { public: - static std::unique_ptr Lower( + static StatusOr> Lower( const string& computation_name, const SessionComputation& session_computation, VersionedComputationHandle::Version version, UserComputation::HloComputationResolver hlo_resolver, - bool include_unused_parameters) { + bool include_unreachable_instructions) { ComputationLowerer lowerer(computation_name, session_computation, version, std::move(hlo_resolver)); - return lowerer.Lower(include_unused_parameters); + return lowerer.Lower(include_unreachable_instructions); } private: @@ -1617,13 +1868,20 @@ class ComputationLowerer { // Build an HLO computation from the SessionComputation at the given // version. - std::unique_ptr Lower(bool include_unused_parameters); + StatusOr> Lower( + bool include_unreachable_instructions); private: + // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. + void TraversePostorder( + const ComputationDataHandle& root, + std::unordered_map* visited, + const std::function& visit); + // DFS visitor of the UserComputation operations which lowers the operations // to HLO instructions. 
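+  // Assumes the operands of 'handle' are already present in 'instructions'
+  // (the postorder traversal guarantees this), so each call lowers exactly
+  // one operation and records the result under the handle's value.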
- HloInstruction* Visit(const ComputationDataHandle& handle, - std::map* visited); + void Visit(const ComputationDataHandle& handle, + std::unordered_map* instructions); // Resolves a ComputationHandle and Version to a previously lowered // HloComputation using the hlo_resolver_ function. @@ -1631,70 +1889,319 @@ class ComputationLowerer { const ComputationHandle& handle, VersionedComputationHandle::Version version); + // This function takes an input value which is being implicitly broadcast into + // an output shape and figures out the right kBroadcast instruction(s) + // necessary to replicate the implicit broadcast semantics explicitly. + HloInstruction* ImplicitBroadcastToExplicitBroadcast( + HloInstruction* operand, const Shape& output_shape); + HloComputation::Builder hlo_builder_; const SessionComputation& session_computation_; const VersionedComputationHandle::Version version_; const UserComputation::HloComputationResolver hlo_resolver_; }; -std::unique_ptr ComputationLowerer::Lower( - bool include_unused_parameters) { +// Calls 'apply' on each operand of 'request'. +static void ForEachOperand( + const OperationRequest& request, + const std::function& apply) { + switch (request.request().op_case()) { + case OpRequest::kRngRequest: { + const RngRequest& rng_request = request.request().rng_request(); + for (const ComputationDataHandle& param : rng_request.parameter()) { + apply(param); + } + break; + } + + case OpRequest::kConstantRequest: + break; + case OpRequest::kGetTupleElementRequest: { + const GetTupleElementRequest& get_tuple_element_request = + request.request().get_tuple_element_request(); + apply(get_tuple_element_request.operand()); + break; + } + + case OpRequest::kSliceRequest: { + const SliceRequest& slice_request = request.request().slice_request(); + apply(slice_request.operand()); + break; + } + + case OpRequest::kDynamicSliceRequest: { + const DynamicSliceRequest& dynamic_slice_request = + request.request().dynamic_slice_request(); + apply(dynamic_slice_request.operand()); + apply(dynamic_slice_request.start_indices()); + break; + } + + case OpRequest::kDynamicUpdateSliceRequest: { + const DynamicUpdateSliceRequest& dynamic_update_slice_request = + request.request().dynamic_update_slice_request(); + apply(dynamic_update_slice_request.operand()); + apply(dynamic_update_slice_request.update()); + apply(dynamic_update_slice_request.start_indices()); + break; + } + + case OpRequest::kConcatenateRequest: { + const ConcatenateRequest& concatenate_request = + request.request().concatenate_request(); + for (const ComputationDataHandle& handle : + concatenate_request.operands()) { + apply(handle); + } + break; + } + + case OpRequest::kConvolveRequest: { + const ConvolveRequest& convolve_request = + request.request().convolve_request(); + apply(convolve_request.lhs()); + apply(convolve_request.rhs()); + break; + } + + case OpRequest::kCrossReplicaSumRequest: { + const CrossReplicaSumRequest& cross_replica_sum_request = + request.request().cross_replica_sum_request(); + apply(cross_replica_sum_request.operand()); + break; + } + + case OpRequest::kInfeedRequest: + break; + + case OpRequest::kOutfeedRequest: { + const OutfeedRequest& outfeed_request = + request.request().outfeed_request(); + apply(outfeed_request.operand()); + break; + } + + case OpRequest::kMapRequest: { + const MapRequest& map_request = request.request().map_request(); + for (const ComputationDataHandle& handle : map_request.operands()) { + apply(handle); + } + break; + } + + case OpRequest::kReduceRequest: { 
+ const ReduceRequest& reduce_request = request.request().reduce_request(); + apply(reduce_request.operand()); + apply(reduce_request.init_value()); + break; + } + + case OpRequest::kReduceWindowRequest: { + const ReduceWindowRequest& reduce_window_request = + request.request().reduce_window_request(); + apply(reduce_window_request.operand()); + apply(reduce_window_request.init_value()); + break; + } + + case OpRequest::kSelectAndScatterRequest: { + const SelectAndScatterRequest& select_and_scatter_request = + request.request().select_and_scatter_request(); + apply(select_and_scatter_request.operand()); + apply(select_and_scatter_request.source()); + apply(select_and_scatter_request.init_value()); + + break; + } + + case OpRequest::kBroadcastRequest: { + const BroadcastRequest& broadcast_request = + request.request().broadcast_request(); + apply(broadcast_request.operand()); + break; + } + + case OpRequest::kReshapeRequest: { + const ReshapeRequest& reshape_request = + request.request().reshape_request(); + apply(reshape_request.operand()); + break; + } + + case OpRequest::kTransposeRequest: { + const TransposeRequest& transpose_request = + request.request().transpose_request(); + apply(transpose_request.operand()); + break; + } + + case OpRequest::kReverseRequest: { + const ReverseRequest& reverse_request = + request.request().reverse_request(); + apply(reverse_request.operand()); + break; + } + + case OpRequest::kPadRequest: { + const PadRequest& pad_request = request.request().pad_request(); + apply(pad_request.operand()); + apply(pad_request.padding_value()); + break; + } + + case OpRequest::kRecvRequest: + case OpRequest::kParameterRequest: + break; + + case OpRequest::kConvertRequest: { + const ConvertRequest& convert_request = + request.request().convert_request(); + apply(convert_request.operand()); + break; + } + + case OpRequest::kWhileRequest: { + const WhileRequest& while_request = request.request().while_request(); + apply(while_request.init()); + break; + } + + case OpRequest::kTernaryOpRequest: { + const TernaryOpRequest& ternary_op_request = + request.request().ternary_op_request(); + apply(ternary_op_request.lhs()); + apply(ternary_op_request.rhs()); + apply(ternary_op_request.ehs()); + break; + } + + case OpRequest::kVariadicOpRequest: { + const VariadicOpRequest& variadic_op_request = + request.request().variadic_op_request(); + for (const ComputationDataHandle& handle : + variadic_op_request.operands()) { + apply(handle); + } + break; + } + + case OpRequest::kCallRequest: { + const CallRequest& call_request = request.request().call_request(); + for (const ComputationDataHandle& handle : call_request.operands()) { + apply(handle); + } + break; + } + + case OpRequest::kCustomCallRequest: { + const CustomCallRequest& cc_request = + request.request().custom_call_request(); + for (const ComputationDataHandle& operand : cc_request.operands()) { + apply(operand); + } + break; + } + + case OpRequest::kUnaryOpRequest: { + const UnaryOpRequest& unary_op_request = + request.request().unary_op_request(); + apply(unary_op_request.operand()); + break; + } + + case OpRequest::kBinaryOpRequest: { + const BinaryOpRequest& binary_op_request = + request.request().binary_op_request(); + apply(binary_op_request.rhs()); + apply(binary_op_request.lhs()); + break; + } + + case OpRequest::kTraceRequest: { + const TraceRequest& trace_request = request.request().trace_request(); + apply(trace_request.operand()); + break; + } + + case OpRequest::kSendRequest: { + const SendRequest& send_request = 
request.request().send_request(); + apply(send_request.operand()); + break; + } + + case OpRequest::OP_NOT_SET: + LOG(FATAL) << "OperationRequest doesn't contain a request"; + + default: + LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); + } +} + +void ComputationLowerer::TraversePostorder( + const ComputationDataHandle& root, + std::unordered_map* visited, + const std::function& visit) { + // Stack containing {handle, enter} pairs. The 'enter' value describes whether + // we are entering or leaving 'handle'. + std::stack> work; + work.push({root, true}); + while (!work.empty()) { + ComputationDataHandle handle; + bool enter; + std::tie(handle, enter) = work.top(); + work.pop(); + + if (enter) { + // We are entering 'handle'. The first time we enter 'handle', we add it + // to 'visited' with a nullptr value. If 'handle' is already in 'visited', + // we do not visit it again. This algorithm only uses the presence of + // a handle in 'visited', but we use a map so we can use the same data + // structure to store the HloInstruction outputs. + if (visited->emplace(handle.handle(), nullptr).second) { + const OperationRequest& request = + session_computation_.requests().at(handle.handle()); + // Push the corresponding 'leave' action onto the stack, followed by + // the operands. + work.push({handle, false}); + ForEachOperand(request, [&work](const ComputationDataHandle& child) { + work.push({child, true}); + }); + } + } else { + // We are leaving 'handle'. We have visited the operands of 'handle', and + // now can visit the 'handle' itself. + visit(handle); + } + } +} + +StatusOr> ComputationLowerer::Lower( + bool include_unreachable_instructions) { // Map from ComputationDataHandle to HLO instruction. Serves as a record of // which operations have been visited as well as a cache for looking up // ComputationDataHandles as HloInstructions. - std::map visited; + std::unordered_map instructions; - // A version is simply a ComputationDataHandle of the root of the computation - // at the time the version was generated. Create a ComputationDataHandle with - // this value and pass it to the visitor as the root of the computation to - // lower. - ComputationDataHandle root_handle; - root_handle.set_handle(version_); + TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, + GetRoot(version_, session_computation_)); - HloInstruction* hlo_root = Visit(root_handle, &visited); + auto visit = [&](const ComputationDataHandle& handle) { + Visit(handle, &instructions); + }; + TraversePostorder(root_request->output_handle(), &instructions, visit); + HloInstruction* hlo_root = + instructions.at(root_request->output_handle().handle()); - // A computation may have unused parameters. - if (include_unused_parameters) { + if (include_unreachable_instructions) { + // Iterate through all computation data handles, and visit any unvisited + // operations. for (int64 request_num = 1; request_num <= version_; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest && - visited.count(request.output_handle().handle()) == 0) { - Visit(request.output_handle(), &visited); - } + TF_ASSIGN_OR_RETURN(const OperationRequest* request, + LookUpRequest(request_num, session_computation_)); + TraversePostorder(request->output_handle(), &instructions, visit); } } - // Add trace instructions. 
- for (const auto& trace_request : session_computation_.trace_requests()) { - if (trace_request.operand().handle() <= version_) { - HloInstruction* operand = visited[trace_request.operand().handle()]; - // Trace instructions cannot be the root of a computation. - HloInstruction* trace_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(trace_instruction); - } - } - - // Send instructions do not have users, so they are not reachable from the - // root instruction. Therefore, explicitly visit all Send requests (and their - // operand chains) and add to the builder. - for (const auto& send_request : session_computation_.send_requests()) { - Visit(send_request.operand(), &visited); - HloInstruction* operand = visited[send_request.operand().handle()]; - hlo_builder_.AddInstruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - } - - // Outfeed instructions do not have users. Explicitly visit all Outfeed - // requests (and their operand chains). - for (const auto& outfeed_request : session_computation_.outfeed_requests()) { - Visit(outfeed_request.operand(), &visited); - HloInstruction* operand = visited[outfeed_request.operand().handle()]; - hlo_builder_.AddInstruction(HloInstruction::CreateOutfeed( - operand, outfeed_request.outfeed_config())); - } - return hlo_builder_.Build(hlo_root); } @@ -1705,24 +2212,62 @@ HloComputation* ComputationLowerer::ResolveComputation( return hlo_resolver_(checked_handle); } -HloInstruction* ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::map* visited) { - if (visited->count(handle.handle()) != 0) { - return (*visited)[handle.handle()]; +HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( + HloInstruction* operand, const Shape& output_shape) { + CHECK(ShapeUtil::IsScalar(operand->shape()) || + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), AsInt64Slice(output_shape.dimensions())); + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand->shape())) { + return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions()))); } + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + if (operand->shape().dimensions(i) > 1) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand->shape().dimensions(i)); + } + } + // Eliminate the size one dimensions. + HloInstruction* reshaped_operand = + hlo_builder_.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand->shape().element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. 
+ return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, reshaped_operand, broadcast_dimensions)); +} +void ComputationLowerer::Visit( + const ComputationDataHandle& handle, + std::unordered_map* instructions) { + CHECK_LE(handle.handle(), version_); + CHECK(instructions->at(handle.handle()) == nullptr); const OperationRequest& request = session_computation_.requests().at(handle.handle()); + auto add_instruction = [&](std::unique_ptr instruction) { + HloInstruction* hlo_instruction = + hlo_builder_.AddInstruction(std::move(instruction)); + hlo_instruction->set_metadata(request.request().metadata()); + return hlo_instruction; + }; + auto lookup_instruction = [&](const ComputationDataHandle& handle) { + return instructions->at(handle.handle()); + }; HloInstruction* hlo_instruction; switch (request.request().op_case()) { case OpRequest::kRngRequest: { const RngRequest& rng_request = request.request().rng_request(); std::vector parameters; for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(Visit(param, visited)); + parameters.push_back(lookup_instruction(param)); } - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng( + hlo_instruction = add_instruction(HloInstruction::CreateRng( request.output_shape(), rng_request.distribution(), parameters)); break; } @@ -1730,9 +2275,8 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kConstantRequest: { const ConstantRequest& constant_request = request.request().constant_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(constant_request.literal()))); + hlo_instruction = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); break; } @@ -1740,35 +2284,34 @@ HloInstruction* ComputationLowerer::Visit( const GetTupleElementRequest& get_tuple_element_request = request.request().get_tuple_element_request(); HloInstruction* operand = - Visit(get_tuple_element_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, - get_tuple_element_request.index())); + lookup_instruction(get_tuple_element_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( + request.output_shape(), operand, get_tuple_element_request.index())); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = Visit(slice_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice( + HloInstruction* operand = lookup_instruction(slice_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateSlice( request.output_shape(), operand, AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()))); + AsInt64Slice(slice_request.limit_indices()), + AsInt64Slice(slice_request.stride()))); break; } case OpRequest::kDynamicSliceRequest: { const DynamicSliceRequest& dynamic_slice_request = request.request().dynamic_slice_request(); - HloInstruction* operand = Visit(dynamic_slice_request.operand(), visited); + HloInstruction* operand = + lookup_instruction(dynamic_slice_request.operand()); HloInstruction* start_indices = - Visit(dynamic_slice_request.start_indices(), visited); + lookup_instruction(dynamic_slice_request.start_indices()); - hlo_instruction = - 
hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); + hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( + request.output_shape(), operand, start_indices, + AsInt64Slice(dynamic_slice_request.slice_sizes()))); break; } @@ -1776,13 +2319,13 @@ HloInstruction* ComputationLowerer::Visit( const DynamicUpdateSliceRequest& dynamic_update_slice_request = request.request().dynamic_update_slice_request(); HloInstruction* operand = - Visit(dynamic_update_slice_request.operand(), visited); + lookup_instruction(dynamic_update_slice_request.operand()); HloInstruction* update = - Visit(dynamic_update_slice_request.update(), visited); + lookup_instruction(dynamic_update_slice_request.update()); HloInstruction* start_indices = - Visit(dynamic_update_slice_request.start_indices(), visited); + lookup_instruction(dynamic_update_slice_request.start_indices()); hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + add_instruction(HloInstruction::CreateDynamicUpdateSlice( request.output_shape(), operand, update, start_indices)); break; } @@ -1793,24 +2336,22 @@ HloInstruction* ComputationLowerer::Visit( std::vector operands; for (const ComputationDataHandle& handle : concatenate_request.operands()) { - HloInstruction* operand = Visit(handle, visited); + HloInstruction* operand = lookup_instruction(handle); operands.push_back(operand); } - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateConcatenate(request.output_shape(), operands, - concatenate_request.dimension())); + hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( + request.output_shape(), operands, concatenate_request.dimension())); break; } case OpRequest::kConvolveRequest: { const ConvolveRequest& convolve_request = request.request().convolve_request(); - HloInstruction* lhs = Visit(convolve_request.lhs(), visited); - HloInstruction* rhs = Visit(convolve_request.rhs(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); + HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); + HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateConvolve( + request.output_shape(), lhs, rhs, convolve_request.window(), + convolve_request.dimension_numbers())); break; } @@ -1818,28 +2359,25 @@ HloInstruction* ComputationLowerer::Visit( const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = - Visit(cross_replica_sum_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + lookup_instruction(cross_replica_sum_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( + request.output_shape(), operand)); break; } case OpRequest::kInfeedRequest: { const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); + hlo_instruction = add_instruction(HloInstruction::CreateInfeed( + request.output_shape(), infeed_request.config())); break; } case OpRequest::kOutfeedRequest: { const OutfeedRequest& 
outfeed_request = request.request().outfeed_request(); - HloInstruction* operand = Visit(outfeed_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateOutfeed( - operand, outfeed_request.outfeed_config())); + HloInstruction* operand = lookup_instruction(outfeed_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( + outfeed_request.shape(), operand, outfeed_request.outfeed_config())); break; } @@ -1847,7 +2385,7 @@ HloInstruction* ComputationLowerer::Visit( const MapRequest& map_request = request.request().map_request(); std::vector operands; for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = Visit(handle, visited); + HloInstruction* operand = lookup_instruction(handle); operands.push_back(operand); } CHECK_EQ(1, request.embedded_computation_versions_size()); @@ -1855,42 +2393,42 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* map_computation = ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap( + hlo_instruction = add_instruction(HloInstruction::CreateMap( request.output_shape(), operands, map_computation)); break; } case OpRequest::kReduceRequest: { const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = Visit(reduce_request.operand(), visited); - HloInstruction* init_value = Visit(reduce_request.init_value(), visited); + HloInstruction* operand = lookup_instruction(reduce_request.operand()); + HloInstruction* init_value = + lookup_instruction(reduce_request.init_value()); CHECK_EQ(1, request.embedded_computation_versions_size()); VersionedComputationHandle::Version reduce_version = request.embedded_computation_versions(0); HloComputation* reduce_computation = ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduce( + request.output_shape(), operand, init_value, + AsInt64Slice(reduce_request.dimensions()), reduce_computation)); break; } case OpRequest::kReduceWindowRequest: { const ReduceWindowRequest& reduce_window_request = request.request().reduce_window_request(); - HloInstruction* operand = Visit(reduce_window_request.operand(), visited); + HloInstruction* operand = + lookup_instruction(reduce_window_request.operand()); HloInstruction* init_value = - Visit(reduce_window_request.init_value(), visited); + lookup_instruction(reduce_window_request.init_value()); CHECK_EQ(1, request.embedded_computation_versions_size()); VersionedComputationHandle::Version reduce_window_version = request.embedded_computation_versions(0); HloComputation* reduce_window_computation = ResolveComputation( reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( + request.output_shape(), operand, init_value, + reduce_window_request.window(), reduce_window_computation)); break; } @@ -1898,11 +2436,11 @@ HloInstruction* ComputationLowerer::Visit( const SelectAndScatterRequest& 
select_and_scatter_request = request.request().select_and_scatter_request(); HloInstruction* operand = - Visit(select_and_scatter_request.operand(), visited); + lookup_instruction(select_and_scatter_request.operand()); HloInstruction* source = - Visit(select_and_scatter_request.source(), visited); + lookup_instruction(select_and_scatter_request.source()); HloInstruction* init_value = - Visit(select_and_scatter_request.init_value(), visited); + lookup_instruction(select_and_scatter_request.init_value()); CHECK_EQ(2, request.embedded_computation_versions_size()); VersionedComputationHandle::Version select_version = request.embedded_computation_versions(0); @@ -1912,18 +2450,17 @@ HloInstruction* ComputationLowerer::Visit( select_and_scatter_request.select(), select_version); HloComputation* scatter_computation = ResolveComputation( select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( + request.output_shape(), operand, select_computation, + select_and_scatter_request.window(), source, init_value, + scatter_computation)); break; } case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); - HloInstruction* operand = Visit(broadcast_request.operand(), visited); + HloInstruction* operand = lookup_instruction(broadcast_request.operand()); std::vector broadcast_dimensions; // The client-level broadcast instruction just appends dimensions on the // left (adds lowest numbered dimensions). The HLO broadcast op is more @@ -1932,50 +2469,64 @@ HloInstruction* ComputationLowerer::Visit( // to append dimensions on the left the broadcast_dimensions should just // be the n highest dimension numbers of the output shape where n is // the number of input dimensions. 
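+      // For example (hypothetical ranks): a rank-2 operand broadcast into a
+      // rank-4 output shape yields broadcast_dimensions={2, 3}.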
+ broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { broadcast_dimensions.push_back(i + ShapeUtil::Rank(request.output_shape()) - ShapeUtil::Rank(operand->shape())); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); + hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( + request.output_shape(), operand, broadcast_dimensions)); break; } case OpRequest::kReshapeRequest: { const ReshapeRequest& reshape_request = request.request().reshape_request(); - HloInstruction* operand = Visit(reshape_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReshape( - request.output_shape(), - hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation( - AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))))); + HloInstruction* operand = lookup_instruction(reshape_request.operand()); + HloInstruction* transposed; + if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { + transposed = operand; + } else { + transposed = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(reshape_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(reshape_request.dimensions()))); + } + hlo_instruction = add_instruction( + HloInstruction::CreateReshape(request.output_shape(), transposed)); + break; + } + + case OpRequest::kTransposeRequest: { + const TransposeRequest& transpose_request = + request.request().transpose_request(); + HloInstruction* operand = lookup_instruction(transpose_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(transpose_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(transpose_request.dimensions()))); break; } case OpRequest::kReverseRequest: { const ReverseRequest& reverse_request = request.request().reverse_request(); - HloInstruction* operand = Visit(reverse_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); + HloInstruction* operand = lookup_instruction(reverse_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateReverse( + request.output_shape(), operand, + AsInt64Slice(reverse_request.dimensions()))); break; } case OpRequest::kPadRequest: { const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = Visit(pad_request.operand(), visited); + HloInstruction* operand = lookup_instruction(pad_request.operand()); HloInstruction* padding_value = - Visit(pad_request.padding_value(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad( + lookup_instruction(pad_request.padding_value()); + hlo_instruction = add_instruction(HloInstruction::CreatePad( request.output_shape(), operand, padding_value, pad_request.padding_config())); break; @@ -1983,7 +2534,7 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRecv( + hlo_instruction = 
add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); break; } @@ -1991,18 +2542,17 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kParameterRequest: { const ParameterRequest& parameter_request = request.request().parameter_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); + hlo_instruction = add_instruction(HloInstruction::CreateParameter( + parameter_request.parameter(), request.output_shape(), + parameter_request.name())); break; } case OpRequest::kConvertRequest: { const ConvertRequest& convert_request = request.request().convert_request(); - HloInstruction* operand = Visit(convert_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction* operand = lookup_instruction(convert_request.operand()); + hlo_instruction = add_instruction( HloInstruction::CreateConvert(request.output_shape(), operand)); break; } @@ -2018,8 +2568,8 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(1); HloComputation* body = ResolveComputation(while_request.body(), body_version); - HloInstruction* init = Visit(while_request.init(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile( + HloInstruction* init = lookup_instruction(while_request.init()); + hlo_instruction = add_instruction(HloInstruction::CreateWhile( request.output_shape(), condition, body, init)); break; } @@ -2027,13 +2577,12 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); - HloInstruction* lhs = Visit(ternary_op_request.lhs(), visited); - HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited); - HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited); + HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); + HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); + HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); + hlo_instruction = add_instruction(HloInstruction::CreateTernary( + request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; } @@ -2043,14 +2592,13 @@ HloInstruction* ComputationLowerer::Visit( std::vector operands; for (const ComputationDataHandle& handle : variadic_op_request.operands()) { - HloInstruction* operand = Visit(handle, visited); + HloInstruction* operand = lookup_instruction(handle); operands.push_back(operand); } auto hlo_opcode = VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); + hlo_instruction = add_instruction(HloInstruction::CreateVariadic( + request.output_shape(), hlo_opcode, operands)); break; } @@ -2058,14 +2606,14 @@ HloInstruction* ComputationLowerer::Visit( const CallRequest& call_request = request.request().call_request(); std::vector operands; for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(Visit(handle, visited)); + operands.push_back(lookup_instruction(handle)); } CHECK_EQ(1, request.embedded_computation_versions_size()); 
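+      // The single version recorded when the call was enqueued picks out the
+      // HloComputation that was lowered for exactly that version of the
+      // callee.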
VersionedComputationHandle::Version call_version = request.embedded_computation_versions(0); HloComputation* call_computation = ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall( + hlo_instruction = add_instruction(HloInstruction::CreateCall( request.output_shape(), operands, call_computation)); break; } @@ -2075,20 +2623,19 @@ HloInstruction* ComputationLowerer::Visit( request.request().custom_call_request(); std::vector operands; for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(Visit(operand, visited)); + operands.push_back(lookup_instruction(operand)); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); + hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( + cc_request.shape(), operands, cc_request.call_target_name())); break; } case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); - HloInstruction* operand = Visit(unary_op_request.operand(), visited); + HloInstruction* operand = lookup_instruction(unary_op_request.operand()); auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary( + hlo_instruction = add_instruction(HloInstruction::CreateUnary( request.output_shape(), hlo_opcode, operand)); break; } @@ -2096,8 +2643,8 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); - HloInstruction* lhs = Visit(binary_op_request.lhs(), visited); - HloInstruction* rhs = Visit(binary_op_request.rhs(), visited); + HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); + HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); if (binary_op_request.broadcast_dimensions_size() > 0) { // Emit a broadcast instruction to perform the "broadcast in dimension" @@ -2116,16 +2663,45 @@ HloInstruction* ComputationLowerer::Visit( // identical to the HLO broadcast semantics so the broadcast_dimensions // field can just be passed to the instruction builder. HloInstruction* broadcasted_operand = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + add_instruction(HloInstruction::CreateBroadcast( broadcast_shape, operand_to_broadcast, AsInt64Slice(binary_op_request.broadcast_dimensions()))); lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); + if (legacy_flags::GetUserComputationFlags() + ->xla_eliminate_hlo_implicit_broadcast) { + if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + // lhs side is being implicitly broadcast. Change to explicit. 
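+          // Replaces 'lhs' with an equivalent explicit broadcast to the
+          // output dimensions; 'rhs' gets the same treatment just below.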
+          lhs =
+              ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
+        }
+
+        if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
+          rhs =
+              ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
+        }
+      }
+      hlo_instruction = add_instruction(HloInstruction::CreateBinary(
+          request.output_shape(), hlo_opcode, lhs, rhs));
+      break;
+    }
+
+    case OpRequest::kTraceRequest: {
+      const TraceRequest& trace_request = request.request().trace_request();
+      HloInstruction* operand = lookup_instruction(trace_request.operand());
+      hlo_instruction = add_instruction(
+          HloInstruction::CreateTrace(trace_request.tag(), operand));
+      operand->set_tracing(hlo_instruction);
+      break;
+    }
+
+    case OpRequest::kSendRequest: {
+      const SendRequest& send_request = request.request().send_request();
+      HloInstruction* operand = lookup_instruction(send_request.operand());
+      hlo_instruction = add_instruction(HloInstruction::CreateSend(
+          operand, send_request.channel_handle().handle()));
       break;
     }
@@ -2135,26 +2711,29 @@ HloInstruction* ComputationLowerer::Visit(
     default:
       LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
   }
-  (*visited)[handle.handle()] = hlo_instruction;
-  return hlo_instruction;
+  (*instructions)[handle.handle()] = hlo_instruction;
 }

 }  // namespace

 StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
     VersionedComputationHandle::Version version,
-    HloComputationResolver hlo_resolver, bool include_unused_parameters) const {
+    HloComputationResolver hlo_resolver,
+    bool include_unreachable_instructions) const {
   tensorflow::mutex_lock lock(mutex_);

   VLOG(2) << "Building HloComputation from UserComputation " << name_
-          << " at version " << version << ". Operation requests:\n"
-          << session_computation_.ShortDebugString();
+          << " at version " << version;
+  XLA_VLOG_LINES(3, session_computation_.DebugString());

-  std::unique_ptr<HloComputation> hlo_computation = ComputationLowerer::Lower(
-      tensorflow::strings::StrCat(name(), ".v", version), session_computation_,
-      version, std::move(hlo_resolver), include_unused_parameters);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloComputation> hlo_computation,
+      ComputationLowerer::Lower(
+          tensorflow::strings::StrCat(name(), ".v", version),
+          session_computation_, version, std::move(hlo_resolver),
+          include_unreachable_instructions));

-  VLOG(2) << "HloComputation:\n" << hlo_computation->ToString();
+  XLA_VLOG_LINES(2, hlo_computation->ToString());
   return std::move(hlo_computation);
 }
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 2be448466f5..fb5425ae61a 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -144,6 +144,10 @@ class UserComputation {
   StatusOr<ComputationDataHandle> AddReshapeInstruction(
       const ReshapeRequest& reshape_request);

+  // Enqueues a transpose instruction onto this user computation.
+  StatusOr<ComputationDataHandle> AddTransposeInstruction(
+      const TransposeRequest& transpose_request);
+
   // Enqueues a slice instruction onto this user computation.
   StatusOr<ComputationDataHandle> AddSliceInstruction(
       const SliceRequest& slice_request);
@@ -236,20 +240,24 @@ class UserComputation {
   // Returns the output shape of the operation indicated by the given handle.
   StatusOr<Shape> GetShape(const ComputationDataHandle& handle);

+  // Sets metadata on the Hlo instruction referenced by the given handle.
+  Status SetOpMetadata(const ComputationDataHandle& handle,
+                       const OpMetadata& metadata);
+
   // Builds a HLO computation from the UserComputation.
The parameter "resolver" // is a function which returns a pointer to the HloComputation corresponding // to the given ComputationHandle at the given version. The resolver is used // for operations, such as map, which call other computations and need a // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unused_computation is true, then all parameter - // instructions are lowered into HloInstructions even if the parameter is - // unused (the root of the computation is unreachable from the parameter). + // instructions. If include_unreachable_instructions is true, then + // instructions which are not reachable from the root are lowered into + // HloInstructions. using HloComputationResolver = std::function; StatusOr> BuildHloComputation( VersionedComputationHandle::Version version, HloComputationResolver hlo_resolver, - bool include_unused_parameters = true) const; + bool include_unreachable_instructions = true) const; // Return a vector containing the embedded computations used by this // UserComputation. Only embedded computations which are called directly by @@ -285,13 +293,8 @@ class UserComputation { const std::map& old_to_new) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Returns the OperationRequestion corresponding to the root (result) of the - // computation. - const OperationRequest& GetRoot(VersionedComputationHandle::Version version) - const EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle value. - StatusOr LookupRequest( + // Returns the OperationRequest corresponding to the given handle. + StatusOr LookUpRequest( const ComputationDataHandle& handle) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -305,6 +308,9 @@ class UserComputation { VersionedComputationHandle::Version version) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); + VersionedComputationHandle GetVersionedHandleInternal() const + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Name of the computation. string name_; diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc new file mode 100644 index 00000000000..ea691201263 --- /dev/null +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -0,0 +1,282 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/user_computation.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace {
+
+using UserComputationTest = ::testing::Test;
+
+TEST_F(UserComputationTest, SimpleComputation) {
+  const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
+  const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2});
+
+  // Build a simple three-operation computation:
+  //
+  //   %constant = Constant({123, 42})
+  //   %param = Param(0)
+  //   %outfeed = Outfeed(%constant)
+  //
+  // Build the computation at two different versions and check invariants.
+  ComputationHandle handle;
+  handle.set_handle(123);
+  UserComputation computation("TheComputation", handle);
+
+  ConstantRequest constant_request;
+  *constant_request.mutable_literal() =
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle,
+                         computation.AddConstantInstruction(constant_request));
+
+  ParameterRequest param_request;
+  *param_request.mutable_shape() = kScalarShape;
+  param_request.set_parameter(0);
+  param_request.set_name("param0");
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle,
+                         computation.AddParameterInstruction(param_request));
+  OpMetadata metadata;
+  metadata.set_op_name("meta");
+  TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));
+
+  OutfeedRequest outfeed_request;
+  *outfeed_request.mutable_operand() = constant_handle;
+  outfeed_request.set_outfeed_config("abc");
+  TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request));
+
+  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
+    return nullptr;
+  };
+  {
+    // Test the computation at the latest version. In this case, the most
+    // recently added operation is an outfeed. However, the outfeed is not the
+    // root because outfeeds cannot be the root of a computation.
+    VersionedComputationHandle latest_version =
+        computation.GetVersionedHandle();
+
+    // Program shape should have a single scalar parameter and scalar
+    // result. The outfeed instruction should not affect the program shape.
+    TF_ASSIGN_OR_ASSERT_OK(
+        std::shared_ptr<const ProgramShape> program_shape,
+        computation.ComputeProgramShape(latest_version.version));
+    ASSERT_EQ(1, program_shape->parameters_size());
+    EXPECT_TRUE(
+        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
+    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));
+
+    // Build the HLO computation.
+    TF_ASSIGN_OR_ASSERT_OK(
+        std::unique_ptr<HloComputation> hlo_computation,
+        computation.BuildHloComputation(latest_version.version, hlo_resolver));
+    // There should be one HloInstruction per UserComputation operation.
+    EXPECT_EQ(3, hlo_computation->instruction_count());
+    // The root of the computation should be the parameter instruction (not
+    // the outfeed).
+    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
+  }
+
+  {
+    // Test the computation at the version right after the parameter
+    // instruction is added.
+    VersionedComputationHandle version_at_param =
+        computation.GetVersionedHandleAtOperation(param_handle);
+
+    // Program shape should have a single scalar parameter, and scalar result.
+    TF_ASSIGN_OR_ASSERT_OK(
+        std::shared_ptr<const ProgramShape> program_shape,
+        computation.ComputeProgramShape(version_at_param.version));
+    ASSERT_EQ(1, program_shape->parameters_size());
+    EXPECT_TRUE(
+        ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
+    EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));
+
+    // There should be two instructions, one for the constant and one for the
+    // parameter. The outfeed instruction should not be included.
+    TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation,
+                           computation.BuildHloComputation(
+                               version_at_param.version, hlo_resolver));
+    EXPECT_EQ(2, hlo_computation->instruction_count());
+    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
+  }
+  {
+    // Test the computation at the latest version, but lowered with
+    // include_unreachable_instructions set to false.
+    VersionedComputationHandle latest_version =
+        computation.GetVersionedHandle();
+
+    // Build the HLO computation.
+    TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation,
+                           computation.BuildHloComputation(
+                               latest_version.version, hlo_resolver,
+                               /*include_unreachable_instructions=*/false));
+    // There is only one reachable instruction, the parameter.
+    EXPECT_EQ(1, hlo_computation->instruction_count());
+    // The root of the computation should be the parameter instruction (not
+    // the outfeed).
+    EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
+    EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
+              "meta");
+  }
+}
+
+TEST_F(UserComputationTest, EliminateScalarBroadcast) {
+  if (!legacy_flags::GetUserComputationFlags()
+           ->xla_eliminate_hlo_implicit_broadcast) {
+    return;
+  }
+
+  // Build a binary computation with a scalar broadcast.
+  //
+  //   %a = Constant({123, 42})
+  //   %b = Constant(1)
+  //   %add = Add(%a, %b)
+  ComputationHandle handle;
+  handle.set_handle(123);
+  UserComputation computation("TheComputation", handle);
+
+  ConstantRequest a_request;
+  *a_request.mutable_literal() =
+      LiteralUtil::CreateR1<float>({123.0f, 42.0f})->ToProto();
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
+                         computation.AddConstantInstruction(a_request));
+
+  ConstantRequest b_request;
+  *b_request.mutable_literal() = LiteralUtil::CreateR0<float>(1.0f)->ToProto();
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
+                         computation.AddConstantInstruction(b_request));
+
+  BinaryOpRequest add;
+  add.set_binop(BINOP_ADD);
+  *add.mutable_lhs() = a_handle;
+  *add.mutable_rhs() = b_handle;
+  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
+
+  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
+    return nullptr;
+  };
+  VersionedComputationHandle latest_version = computation.GetVersionedHandle();
+
+  // Build the HLO computation.
+  TF_ASSIGN_OR_ASSERT_OK(
+      std::unique_ptr<HloComputation> hlo_computation,
+      computation.BuildHloComputation(latest_version.version, hlo_resolver));
+  // The binary operation has an implicit scalar broadcast and should be
+  // converted to an explicit broadcast instruction plus a binary instruction.
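+  // (Editor's note) The four instructions are the two constants, the
+  // explicit broadcast of the scalar %b, and the add.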
+  EXPECT_EQ(4, hlo_computation->instruction_count());
+  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
+  const auto& operands = hlo_computation->root_instruction()->operands();
+  ASSERT_EQ(2, operands.size());
+  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
+              operands[1]->opcode() == HloOpcode::kBroadcast);
+}
+
+TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
+  if (!legacy_flags::GetUserComputationFlags()
+           ->xla_eliminate_hlo_implicit_broadcast) {
+    return;
+  }
+
+  // Build a binary computation with an in-dim broadcast and a degenerate
+  // broadcast.
+  //
+  //   %a = Param({2, 3});
+  //   %b = Param({2, 1, 4});
+  //   %add = Add(%a, %b, {0, 1});
+  ComputationHandle handle;
+  handle.set_handle(123);
+  UserComputation computation("TheComputation", handle);
+
+  ParameterRequest a_request;
+  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
+  a_request.set_name("a");
+  a_request.set_parameter(0);
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
+                         computation.AddParameterInstruction(a_request));
+
+  ParameterRequest b_request;
+  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
+  b_request.set_name("b");
+  b_request.set_parameter(1);
+  TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
+                         computation.AddParameterInstruction(b_request));
+
+  BinaryOpRequest add;
+  add.set_binop(BINOP_ADD);
+  *add.mutable_lhs() = a_handle;
+  *add.mutable_rhs() = b_handle;
+  add.add_broadcast_dimensions(0);
+  add.add_broadcast_dimensions(1);
+  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
+
+  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
+    return nullptr;
+  };
+  VersionedComputationHandle latest_version = computation.GetVersionedHandle();
+
+  // Build the HLO computation.
+  TF_ASSIGN_OR_ASSERT_OK(
+      std::unique_ptr<HloComputation> hlo_computation,
+      computation.BuildHloComputation(latest_version.version, hlo_resolver));
+
+  // The binary operation has an in-dim broadcast and a degenerate broadcast;
+  // lowering should first perform the in-dim broadcast and then convert the
+  // degenerate broadcast into a reshape and a broadcast.
+  //
+  //       b         a
+  //       |         |
+  //   broadcast  reshape
+  //       |         |
+  //       |     broadcast
+  //        \       /
+  //           add
+  EXPECT_EQ(6, hlo_computation->instruction_count());
+  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
+  const auto& operands = hlo_computation->root_instruction()->operands();
+  ASSERT_EQ(2, operands.size());
+  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
+              operands[1]->opcode() == HloOpcode::kBroadcast);
+}
+
+}  // namespace
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendUserComputationFlags(&flag_list);
+  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  if (!parse_result) {
+    LOG(ERROR) << "\n" << usage;
+    return 2;
+  }
+  testing::InitGoogleTest(&argc, argv);
+  if (argc > 1) {
+    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+    return 2;
+  }
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/compiler/xla/service/versioned_computation_handle.cc
new file mode 100644
index 00000000000..a693c4695f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/versioned_computation_handle.cc
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+string VersionedComputationHandle::ToString() const {
+  return tensorflow::strings::StrCat(handle.handle(), ":v", version);
+}
+
+std::ostream& operator<<(std::ostream& out,
+                         const VersionedComputationHandle& versioned_handle) {
+  out << versioned_handle.ToString();
+  return out;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h
index 03bee3d4a5f..5732a56caff 100644
--- a/tensorflow/compiler/xla/service/versioned_computation_handle.h
+++ b/tensorflow/compiler/xla/service/versioned_computation_handle.h
@@ -16,8 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_

+#include <ostream>
+
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/types.h"

 namespace xla {

@@ -32,6 +34,8 @@ struct VersionedComputationHandle {
   ComputationHandle handle;
   Version version;
+
+  string ToString() const;

   bool operator==(const VersionedComputationHandle& other) const {
     return (handle.handle() == other.handle.handle()) &&
            (version == other.version);
@@ -43,6 +47,9 @@ struct VersionedComputationHandle {
   }
 };

+std::ostream& operator<<(std::ostream& out,
+                         const VersionedComputationHandle& versioned_handle);
+
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index fc107480f73..809941d8fe1 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -21,7 +21,10 @@ limitations under the License.

 namespace xla {

-// Defines the interface for an XLA service.
+// Defines the interface for an XLA service on the client side. It abstracts
+// over the actual implementation of the service: the service can be local
+// (running in the same process) or remote, in which case an RPC stub is used
+// as the implementation.
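+//
+// Illustrative usage (editor's sketch; the variable names are hypothetical):
+// a client holds a ServiceInterface* without knowing which backend it got,
+//
+//   std::unique_ptr<ServiceInterface> service = /* local or RPC-backed */;
+//   TF_RETURN_IF_ERROR(service->TransferToServer(&arg, &result));
+//
+// and the calling code is identical in both cases.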
 class ServiceInterface {
  public:
  ServiceInterface() {}
@@ -31,23 +34,19 @@ class ServiceInterface {
   virtual tensorflow::Status TransferToClient(
       const TransferToClientRequest* arg, TransferToClientResponse* result) = 0;

-  virtual tensorflow::Status TransferToClientInProcess(
-      const TransferToClientInProcessRequest* arg,
-      TransferToClientInProcessResponse* result) = 0;
-
   virtual tensorflow::Status TransferToServer(
       const TransferToServerRequest* arg, TransferToServerResponse* result) = 0;

   virtual tensorflow::Status TransferToInfeed(
       const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0;

+  virtual tensorflow::Status TransferFromOutfeed(
+      const TransferFromOutfeedRequest* arg,
+      TransferFromOutfeedResponse* result) = 0;
+
   virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                          ResetDeviceResponse* result) = 0;

-  virtual tensorflow::Status TransferToServerInProcess(
-      const TransferToServerInProcessRequest* arg,
-      TransferToServerInProcessResponse* result) = 0;
-
   virtual tensorflow::Status LoadComputationSnapshot(
       const LoadComputationSnapshotRequest* request,
       LoadComputationSnapshotResponse* result) = 0;
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 6963a68d10d..cc456df4fce 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -33,22 +33,67 @@ limitations under the License.

 namespace xla {

+namespace internal {
+
+// Internal representation of each node in a ShapeTree.
+template <typename T>
+struct ShapeTreeNode {
+  // Data corresponding to this node.
+  T data;
+
+  // Children of this node.
+  std::vector<std::unique_ptr<ShapeTreeNode>> children;
+
+  ShapeTreeNode() = default;
+  explicit ShapeTreeNode(const T& data) : data(data) {}
+
+  ShapeTreeNode(const ShapeTreeNode& other)
+      : data(other.data), children(other.children.size()) {
+    for (size_t i = 0; i < children.size(); ++i) {
+      children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
+    }
+  }
+
+  ShapeTreeNode& operator=(const ShapeTreeNode& other) {
+    if (this != &other) {
+      data = other.data;
+      children.resize(other.children.size());
+      for (size_t i = 0; i < children.size(); ++i) {
+        children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
+      }
+    }
+    return *this;
+  }
+};
+
+}  // namespace internal
+
 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
-// XLA shape and holds a value of type T for each array in the shape. For
-// array shapes, a ShapeTree trivially holds a single value of type T. For tuple
-// shapes which can be an arbitrary tree with arrays at the leaves, a ShapeTree
-// is an identically structured tree with data elements of type T at the leaves.
+// XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
+// in the shape. For array shapes, a ShapeTree trivially holds a single value of
+// type T.
+//
+// For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
+// ShapeTree is an identically structured tree with data elements of type T at
+// every node. I.e. the root is a tuple by definition, all interior nodes are
+// also tuples, and all leaves are arrays.
 //
 // Like the Shape data structure, this is a tree and tuple elements cannot be
-// duplicated. That is, every distinct element position in the Shape has a
-// unique T object.
+// duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
+// object.
 template <typename T>
 class ShapeTree {
  public:
+  // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
+  ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
+
+  // Create ShapeTree with the given shape, and default-constructed T values
+  // for all nodes.
   explicit ShapeTree(const Shape& shape);
+
+  // Create ShapeTree with the given shape, and init_value for all nodes.
   ShapeTree(const Shape& shape, const T& init_value);
-  ShapeTree(const ShapeTree& other);
-  ShapeTree& operator=(const ShapeTree& other);
+
+  ShapeTree(const ShapeTree& other) = default;
+  ShapeTree& operator=(const ShapeTree& other) = default;

   // Returns the data element associated with the array in the shape at the
   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
@@ -56,12 +101,12 @@ class ShapeTree {
   T* mutable_element(const ShapeIndex& index);

   // Return the shape represented with this ShapeTree.
-  const Shape& shape() const { return *shape_; }
+  const Shape& shape() const { return shape_; }

   // Returns true if the node at the given index is a leaf node (an array
   // shape).
   bool IsLeaf(const ShapeIndex& index) const {
-    return Lookup(index).elements_.empty();
+    return Lookup(index)->children.empty();
   }

   // Recursively traverses the shape and calls the given function at each
   //
   // index : the index of the element in the shape. See ShapeUtil::GetSubshape
   //         for definition of index.
-  // is_leaf : Whether this element is a leaf element in the shape. That is,
-  //           whether this index corresponds to an array and not a (nested)
-  //           tuple element.
   // data : The data value at this element.
-  //
-  // If any call to the given function returns a non-OK status, then traversal
-  // is aborted and the status value is returned.
-  using VisitorFunction = std::function<tensorflow::Status(
-      const ShapeIndex& /*index*/, bool /*is_leaf*/, const T& /*data*/)>;
-  tensorflow::Status ForEachElement(VisitorFunction func) const;
+  using VisitorFunction =
+      std::function<void(const ShapeIndex& /*index*/, const T& /*data*/)>;
+  void ForEachElement(const VisitorFunction& func) const;

-  using MutableVisitorFunction = std::function<tensorflow::Status(
-      const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
-  tensorflow::Status ForEachMutableElement(MutableVisitorFunction func);
+  using MutableVisitorFunction =
+      std::function<void(const ShapeIndex& /*index*/, T* /*data*/)>;
+  void ForEachMutableElement(const MutableVisitorFunction& func);
+
+  // Variants of ForEach(Mutable)Element which propagate a Status value from
+  // the visitor.
+  using StatusVisitorFunction =
+      std::function<Status(const ShapeIndex& /*index*/, const T& /*data*/)>;
+  Status ForEachElementWithStatus(const StatusVisitorFunction& func) const;
+
+  using MutableStatusVisitorFunction =
+      std::function<Status(const ShapeIndex& /*index*/, T* /*data*/)>;
+  Status ForEachMutableElementWithStatus(
+      const MutableStatusVisitorFunction& func);
+
+  // Copy the subtree of values from 'other' rooted at ShapeIndex
+  // 'source_base_index' into the subtree of values in this ShapeTree rooted
+  // at 'target_base_index'.
+  //
+  // Precondition: The subshape of other.shape() at index source_base_index
+  // must be compatible with the subshape of shape() at index
+  // target_base_index.
+  void CopySubtreeFrom(const ShapeTree& other,
+                       const ShapeIndex& source_base_index,
+                       const ShapeIndex& target_base_index);
+
+  bool operator==(const ShapeTree& other) const;
+  bool operator!=(const ShapeTree& other) const { return !(*this == other); }

  private:
-  // Private default constructor for non-root nodes of the tree.
-  ShapeTree() = default;
+  using Node = internal::ShapeTreeNode<T>;
+
+  // Initialize node->children based on 'shape'. All children are assigned
+  // the given 'init_value'.
+  void InitChildren(const Shape& shape, const T& init_value, Node* node);
+
+  // Initialize node->children based on 'shape'. All children have
+  // default-constructed data values.
+ void InitChildren(const Shape& shape, Node* node); // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). - static tensorflow::Status ForEachHelperMutable(ShapeIndex* index, - ShapeTree* shape_tree, - MutableVisitorFunction func); - static tensorflow::Status ForEachHelper(ShapeIndex* index, - const ShapeTree& shape_tree, - VisitorFunction func); - - // Copy all the data elements (of type T) from "other" into "this". "this" - // must have the same tree structure as "other" prior to calling this method. - void CopyDataElements(const ShapeTree& other); - - // Recursive helper for constructing a subtree beneath "this" node. - void BuildTree(const Shape& shape); + static Status ForEachHelper(const StatusVisitorFunction& func, + const Node& node, ShapeIndex* index); + static Status ForEachMutableHelper(const MutableStatusVisitorFunction& func, + Node* node, ShapeIndex* index); // Return the tree node at the given index. - ShapeTree& Lookup(const ShapeIndex& index); - const ShapeTree& Lookup(const ShapeIndex& index) const; + Node* Lookup(const ShapeIndex& index); + const Node* Lookup(const ShapeIndex& index) const; - // The data corresponding to the array at this node. - T data_; + // The root node, which contains all other nodes. + Node root_; - // The XLA shape mirrored in this ShapeTree. Only the root of the - // ShapeTree has this member set. - std::unique_ptr shape_; - - // The children of this node in the tree. - std::vector> elements_; + // The XLA shape mirrored in this ShapeTree. + Shape shape_; }; template -void ShapeTree::BuildTree(const Shape& shape) { +void ShapeTree::InitChildren(const Shape& shape, const T& init_value, + Node* node) { if (ShapeUtil::IsTuple(shape)) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - elements_.emplace_back(new ShapeTree()); - elements_.back()->BuildTree(shape.tuple_shapes(i)); + node->children.emplace_back(new Node(init_value)); + InitChildren(shape.tuple_shapes(i), init_value, + node->children.back().get()); } } } template -ShapeTree::ShapeTree(const Shape& shape) : shape_(MakeUnique(shape)) { - // The shape_ field is just used to hold the structure of the shape. It should - // not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); +void ShapeTree::InitChildren(const Shape& shape, Node* node) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + node->children.emplace_back(new Node()); + InitChildren(shape.tuple_shapes(i), node->children.back().get()); + } + } +} + +template +ShapeTree::ShapeTree(const Shape& shape) : root_(), shape_(shape) { + // The shape_ field is just used to hold the structure of the shape. + // It should not be relied upon to store layout information. 
+ LayoutUtil::ClearLayout(&shape_); + InitChildren(shape_, &root_); } template ShapeTree::ShapeTree(const Shape& shape, const T& init_value) - : shape_(MakeUnique(shape)) { - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); - TF_CHECK_OK(ForEachMutableElement( - [&init_value](const ShapeIndex& /*index*/, bool /*is_leaf*/, bool* data) { - *data = init_value; - return tensorflow::Status::OK(); - })); -} - -template -ShapeTree::ShapeTree(const ShapeTree& other) - : shape_(MakeUnique(other.shape())) { - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); - CopyDataElements(other); -} - -template -ShapeTree& ShapeTree::operator=(const ShapeTree& other) { - if (this == &other) { - return *this; - } - elements_.clear(); - shape_ = MakeUnique(other.shape()); - LayoutUtil::ClearLayout(shape_.get()); - - BuildTree(*shape_); - CopyDataElements(other); - return *this; -} - -template -void ShapeTree::CopyDataElements(const ShapeTree& other) { - CHECK(ShapeUtil::Compatible(shape(), other.shape())); - TF_CHECK_OK(ForEachMutableElement( - [&other](const ShapeIndex& index, bool /*is_leaf*/, T* data) { - *data = other.element(index); - return tensorflow::Status::OK(); - })); + : root_(init_value), shape_(shape) { + // The shape_ field is just used to hold the structure of the shape. + // It should not be relied upon to store layout information. + LayoutUtil::ClearLayout(&shape_); + InitChildren(shape_, init_value, &root_); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index).data_; + return Lookup(index)->data; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index).data_; + return &Lookup(index)->data; } template -ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) { - ShapeTree* node = this; - for (auto& i : index) { +internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { + Node* node = &root_; + for (const int64 i : index) { CHECK_GE(i, 0); - CHECK_LT(i, node->elements_.size()); - node = node->elements_[i].get(); + CHECK_LT(i, node->children.size()); + node = node->children[i].get(); } - return *node; + return node; } template -const ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) const { - return const_cast*>(this)->Lookup(index); +const internal::ShapeTreeNode* ShapeTree::Lookup( + const ShapeIndex& index) const { + return const_cast(this)->Lookup(index); } /* static */ template -tensorflow::Status ShapeTree::ForEachHelperMutable( - ShapeIndex* index, ShapeTree* shape_tree, - ShapeTree::MutableVisitorFunction func) { - TF_RETURN_IF_ERROR( - func(*index, shape_tree->elements_.empty(), &shape_tree->data_)); - for (int i = 0; i < shape_tree->elements_.size(); ++i) { +Status ShapeTree::ForEachHelper(const StatusVisitorFunction& func, + const Node& node, ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, node.data)); + for (int64 i = 0; i < node.children.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); + index->pop_back(); + } + return Status::OK(); +} + +/* static */ +template +Status ShapeTree::ForEachMutableHelper( + const MutableStatusVisitorFunction& func, Node* node, ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, &node->data)); + for (int64 i = 0; i < node->children.size(); ++i) { index->push_back(i); TF_RETURN_IF_ERROR( - ForEachHelperMutable(index, shape_tree->elements_[i].get(), func)); + ForEachMutableHelper(func, node->children[i].get(), index)); index->pop_back(); } - - return tensorflow::Status::OK(); -} - 
-/* static */ -template -tensorflow::Status ShapeTree::ForEachHelper( - ShapeIndex* index, const ShapeTree& shape_tree, - ShapeTree::VisitorFunction func) { - TF_RETURN_IF_ERROR( - func(*index, shape_tree.elements_.empty(), shape_tree.data_)); - for (int i = 0; i < shape_tree.elements_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(index, *shape_tree.elements_[i], func)); - index->pop_back(); - } - - return tensorflow::Status::OK(); + return Status::OK(); } template -tensorflow::Status ShapeTree::ForEachElement( - ShapeTree::VisitorFunction func) const { +Status ShapeTree::ForEachElementWithStatus( + const StatusVisitorFunction& func) const { ShapeIndex index; - return ForEachHelper(&index, *this, func); + return ForEachHelper(func, root_, &index); } template -tensorflow::Status ShapeTree::ForEachMutableElement( - ShapeTree::MutableVisitorFunction func) { +Status ShapeTree::ForEachMutableElementWithStatus( + const MutableStatusVisitorFunction& func) { ShapeIndex index; - return ForEachHelperMutable(&index, this, func); + return ForEachMutableHelper(func, &root_, &index); +} + +template +void ShapeTree::ForEachElement(const VisitorFunction& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const T& data) { + func(index, data); + return Status::OK(); + }, + root_, &index) + .IgnoreError(); +} + +template +void ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, T* data) { + func(index, data); + return Status::OK(); + }, + &root_, &index) + .IgnoreError(); +} + +template +void ShapeTree::CopySubtreeFrom(const ShapeTree& other, + const ShapeIndex& source_base_index, + const ShapeIndex& target_base_index) { + CHECK(ShapeUtil::Compatible( + ShapeUtil::GetSubshape(shape(), target_base_index), + ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( + const ShapeIndex& index, T* data) { + // Copy the data element only if index is in the + // subtree rooted at target_base_index. + for (int i = 0; i < target_base_index.size(); ++i) { + if (i >= index.size() || index[i] != target_base_index[i]) { + return; + } + } + // Construct source element index to copy from. + ShapeIndex source_index = source_base_index; + for (int i = target_base_index.size(); i < index.size(); ++i) { + source_index.push_back(index[i]); + } + *data = other.element(source_index); + }); +} + +template +bool ShapeTree::operator==(const ShapeTree& other) const { + bool equal = true; + ForEachElement( + [this, &other, &equal](const ShapeIndex& index, const T& data) { + if (data != other.element(index)) { + equal = false; + } + }); + return equal; } } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index d37f536b755..afc3a2b2a34 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -35,6 +35,9 @@ class ShapeTreeTest : public ::testing::Test { array_shape_})}); } + void TestShapeConstructor(const Shape& shape, int expected_num_nodes); + void TestInitValueConstructor(const Shape& shape, int expected_num_nodes); + // An array shape (non-tuple). Shape array_shape_; @@ -45,6 +48,73 @@ class ShapeTreeTest : public ::testing::Test { Shape nested_tuple_shape_; }; +TEST_F(ShapeTreeTest, DefaultConstructor) { + ShapeTree int_tree; + EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + + ShapeTree bool_tree; + EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); +} + +void ShapeTreeTest::TestShapeConstructor(const Shape& shape, + int expected_num_nodes) { + ShapeTree int_tree(shape); + int num_nodes = 0; + int_tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { + EXPECT_EQ(0, data); + ++num_nodes; + }); + EXPECT_EQ(expected_num_nodes, num_nodes); + + ShapeTree bool_tree(shape); + num_nodes = 0; + bool_tree.ForEachElement( + [&num_nodes](const ShapeIndex& /*index*/, bool data) { + EXPECT_EQ(false, data); + ++num_nodes; + }); + EXPECT_EQ(expected_num_nodes, num_nodes); +} + +TEST_F(ShapeTreeTest, ShapeConstructor) { + TestShapeConstructor(array_shape_, 1); + TestShapeConstructor(tuple_shape_, 4); + TestShapeConstructor(nested_tuple_shape_, 10); +} + +void ShapeTreeTest::TestInitValueConstructor(const Shape& shape, + int expected_num_nodes) { + ShapeTree tree(shape, 42); + int num_nodes = 0; + tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { + EXPECT_EQ(42, data); + ++num_nodes; + }); + EXPECT_EQ(expected_num_nodes, num_nodes); + + num_nodes = 0; + tree.ForEachMutableElement( + [&num_nodes](const ShapeIndex& /*index*/, int* data) { + EXPECT_EQ(42, *data); + *data = num_nodes; + ++num_nodes; + }); + EXPECT_EQ(expected_num_nodes, num_nodes); + + num_nodes = 0; + tree.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { + EXPECT_EQ(num_nodes, data); + ++num_nodes; + }); + EXPECT_EQ(expected_num_nodes, num_nodes); +} + +TEST_F(ShapeTreeTest, InitValueConstructor) { + TestInitValueConstructor(array_shape_, 1); + TestInitValueConstructor(tuple_shape_, 4); + TestInitValueConstructor(nested_tuple_shape_, 10); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; @@ -57,6 +127,15 @@ TEST_F(ShapeTreeTest, ArrayShape) { // Test the copy constructor. ShapeTree copy{shape_tree}; EXPECT_EQ(123, copy.element({})); + + // Mutate the copy, and ensure the original doesn't change. + *copy.mutable_element({}) = 99; + EXPECT_EQ(99, copy.element({})); + EXPECT_EQ(123, shape_tree.element({})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(123, copy.element({})); } TEST_F(ShapeTreeTest, TupleShape) { @@ -74,11 +153,8 @@ TEST_F(ShapeTreeTest, TupleShape) { // Sum all elements in the shape. int sum = 0; - TF_CHECK_OK(shape_tree.ForEachElement( - [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { - sum += data; - return tensorflow::Status::OK(); - })); + shape_tree.ForEachElement( + [&sum](const ShapeIndex& /*index*/, int data) { sum += data; }); EXPECT_EQ(66, sum); // Test the copy constructor. 
@@ -89,15 +165,23 @@ TEST_F(ShapeTreeTest, TupleShape) { EXPECT_EQ(-100, copy.element({2})); // Write zero to all data elements. - TF_CHECK_OK(shape_tree.ForEachMutableElement( - [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int* data) { - *data = 0; - return tensorflow::Status::OK(); - })); + shape_tree.ForEachMutableElement( + [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); EXPECT_EQ(0, shape_tree.element({2})); + EXPECT_EQ(1, copy.element({})); + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1})); + EXPECT_EQ(-100, copy.element({2})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(0, copy.element({})); + EXPECT_EQ(0, copy.element({0})); + EXPECT_EQ(0, copy.element({1})); + EXPECT_EQ(0, copy.element({2})); } TEST_F(ShapeTreeTest, NestedTupleShape) { @@ -116,6 +200,23 @@ TEST_F(ShapeTreeTest, NestedTupleShape) { EXPECT_EQ(42, copy.element({0})); EXPECT_EQ(123, copy.element({1, 1})); EXPECT_EQ(-100, copy.element({2, 0, 1})); + + // Mutate the copy, and ensure the original doesn't change. + *copy.mutable_element({0}) = 1; + *copy.mutable_element({1, 1}) = 2; + *copy.mutable_element({2, 0, 1}) = 3; + EXPECT_EQ(1, copy.element({0})); + EXPECT_EQ(2, copy.element({1, 1})); + EXPECT_EQ(3, copy.element({2, 0, 1})); + EXPECT_EQ(42, shape_tree.element({0})); + EXPECT_EQ(123, shape_tree.element({1, 1})); + EXPECT_EQ(-100, shape_tree.element({2, 0, 1})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1, 1})); + EXPECT_EQ(-100, copy.element({2, 0, 1})); } TEST_F(ShapeTreeTest, InvalidIndexingTuple) { @@ -130,5 +231,139 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { EXPECT_DEATH(shape_tree.element({0, 0}), ""); } +TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { + ShapeTree> shape_tree{tuple_shape_}; + EXPECT_EQ(shape_tree.element({2}).get(), nullptr); + *shape_tree.mutable_element({2}) = MakeUnique(42); + EXPECT_EQ(*shape_tree.element({2}), 42); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromArrayShape) { + // Test CopySubtreeFrom method for a single value copied between array-shaped + // ShapeTrees. + ShapeTree source(array_shape_); + *source.mutable_element(/*index=*/{}) = 42; + ShapeTree destination(array_shape_, 123); + + EXPECT_EQ(destination.element(/*index=*/{}), 123); + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 42); +} + +TEST_F(ShapeTreeTest, FullCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of all elements from one + // tuple-shaped ShapeTree to another. + ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{}); + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2}), 13); +} + +TEST_F(ShapeTreeTest, SingleElementCopySubtreeFromTupleShape) { + // Test CopySubtreeFrom method for a copy of a single element from one + // tuple-shaped ShapeTree to another. 
+ ShapeTree source(tuple_shape_); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + *source.mutable_element(/*index=*/{2}) = 13; + + ShapeTree destination(tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{0}, + /*target_base_index=*/{1}); + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 11); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeIntoNestedShape) { + // Test CopySubtreeFrom method for a copy of a tuple-shaped ShapeTree into a + // nested-tuple-shaped ShapeTree. + ShapeTree source( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_})); + *source.mutable_element(/*index=*/{}) = 10; + *source.mutable_element(/*index=*/{0}) = 11; + *source.mutable_element(/*index=*/{1}) = 12; + + ShapeTree destination(nested_tuple_shape_, 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{}, + /*target_base_index=*/{2, 0}); + + EXPECT_EQ(destination.element(/*index=*/{}), 0); + EXPECT_EQ(destination.element(/*index=*/{0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 0}), 0); + EXPECT_EQ(destination.element(/*index=*/{1, 1}), 0); + EXPECT_EQ(destination.element(/*index=*/{2}), 0); + EXPECT_EQ(destination.element(/*index=*/{2, 0}), 10); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 0}), 11); + EXPECT_EQ(destination.element(/*index=*/{2, 0, 1}), 12); + EXPECT_EQ(destination.element(/*index=*/{2, 1}), 0); +} + +TEST_F(ShapeTreeTest, CopySubtreeFromNestedShape) { + // Test CopySubtreeFrom method for a copy from a nested-tuple-shape. + ShapeTree source(nested_tuple_shape_, 42); + *source.mutable_element(/*index=*/{1}) = 10; + *source.mutable_element(/*index=*/{1, 0}) = 11; + *source.mutable_element(/*index=*/{1, 1}) = 12; + + ShapeTree destination( + ShapeUtil::MakeTupleShape({array_shape_, array_shape_}), 0); + + destination.CopySubtreeFrom(source, /*source_base_index=*/{1}, + /*target_base_index=*/{}); + + EXPECT_EQ(destination.element(/*index=*/{}), 10); + EXPECT_EQ(destination.element(/*index=*/{0}), 11); + EXPECT_EQ(destination.element(/*index=*/{1}), 12); +} + +TEST_F(ShapeTreeTest, OperatorEquals) { + { + ShapeTree a(array_shape_, 123); + ShapeTree b(array_shape_, 42); + ShapeTree c(array_shape_, 42); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + } + { + ShapeTree a(tuple_shape_); + *a.mutable_element(/*index=*/{}) = 10; + *a.mutable_element(/*index=*/{0}) = 11; + *a.mutable_element(/*index=*/{1}) = 12; + + ShapeTree b(tuple_shape_); + *b.mutable_element(/*index=*/{}) = 10; + *b.mutable_element(/*index=*/{0}) = 42; + *b.mutable_element(/*index=*/{1}) = 11; + + ShapeTree c(tuple_shape_); + *c.mutable_element(/*index=*/{}) = 10; + *c.mutable_element(/*index=*/{0}) = 42; + *c.mutable_element(/*index=*/{1}) = 11; + + EXPECT_FALSE(a == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b == c); + EXPECT_FALSE(b != c); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 4acdd71d173..ee49a9ae5f5 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include "tensorflow/compiler/xla/index_util.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -37,6 +39,16 @@ limitations under the License. namespace xla { +string ShapeIndex::ToString() const { + return tensorflow::strings::StrCat( + "{", tensorflow::str_util::Join(indices_, ","), "}"); +} + +std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { + out << shape_index.ToString(); + return out; +} + namespace { // Recursive helper for comparing the equality of two shapes. Returns true if @@ -44,18 +56,11 @@ namespace { // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (ShapeUtil::IsTuple(lhs)) { - if (!ShapeUtil::IsTuple(rhs)) { - VLOG(3) << "CompareShapes: lhs is a tuple, rhs not a tuple"; - return false; - } - - if (!ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); - })) { - VLOG(3) << "CompareShapes: tuples on lhs and rhs not equal"; - return false; - } + return ShapeUtil::IsTuple(rhs) && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { + return CompareShapes(l, r, compare_layouts); + }); } // Explicitly compare the fields rather than using MessageDifferencer because // we want empty layouts to be treated identically to missing layouts. @@ -117,7 +122,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { for (const auto& shape : parameters) { *program_shape.add_parameters() = shape; } - *program_shape.mutable_result() = result; + *program_shape.mutable_result() = std::move(result); return program_shape; } @@ -197,7 +202,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { - shape->mutable_layout()->add_minor_to_major(ShapeUtil::Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); } @@ -290,7 +295,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { std::vector new_elements(tuple.tuple_shapes().begin() + start, tuple.tuple_shapes().begin() + limit); - return ShapeUtil::MakeTupleShape(new_elements); + return MakeTupleShape(new_elements); } /* static */ bool ShapeUtil::IsOpaque(const Shape& shape) { @@ -304,7 +309,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (shape.element_type() != element_type) { return false; } - if (shape.dimensions_size() != ShapeUtil::Rank(shape)) { + if (shape.dimensions_size() != Rank(shape)) { return false; } int64 i = 0; @@ -318,7 +323,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK_EQ(shape.dimensions_size(), ShapeUtil::Rank(shape)); + CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); @@ -329,7 +334,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) 
{ } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && ShapeUtil::Rank(shape) == 0; + return shape.element_type() == F32 && Rank(shape) == 0; } /* static */ string ShapeUtil::HumanString(const Shape& shape) { @@ -427,13 +432,12 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } Shape result; if (layout_string.empty()) { - result = ShapeUtil::MakeShape(primitive_type, dimensions); + result = MakeShape(primitive_type, dimensions); } else { TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + result = MakeShapeWithLayout(primitive_type, dimensions, min2maj); } TF_DCHECK_OK(ValidateShape(result)); return result; @@ -463,7 +467,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += ShapeUtil::Rank(shape); + dimension_number += Rank(shape); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -515,7 +519,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } int64 allocated_element_count; if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(ShapeUtil::Rank(shape), shape.layout().padded_dimensions_size()); + CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); allocated_element_count = 1; for (int64 dimension_size : shape.layout().padded_dimensions()) { allocated_element_count *= dimension_size; @@ -531,9 +535,9 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { const Shape& shape) { if (shape.element_type() == TUPLE) { // Tuple shape. - if (ShapeUtil::Rank(shape) != 0) { + if (Rank(shape) != 0) { return InvalidArgument("tuples must be rank-0; got rank %lld", - ShapeUtil::Rank(shape)); + Rank(shape)); } if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -553,13 +557,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString().c_str()); } - if (ShapeUtil::Rank(shape) != shape.dimensions_size()) { + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " "dimensions_size=%d", - ShapeUtil::Rank(shape), shape.dimensions_size()); + Rank(shape), shape.dimensions_size()); } - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -614,6 +618,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return return_shape; } +/* static */ +bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { + return !IsTuple(GetSubshape(shape, index)); +} + /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { std::vector dimension_sizes; std::vector degenerate_dimensions; @@ -672,7 +681,7 @@ namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in // DFS pre-order starting with the index. 
 Status ForEachSubshapeHelper(const Shape& shape,
-                             const ShapeUtil::VisitorFunction func,
+                             const ShapeUtil::StatusVisitorFunction& func,
                              ShapeIndex* index) {
   TF_RETURN_IF_ERROR(func(shape, *index));
   if (ShapeUtil::IsTuple(shape)) {
@@ -689,7 +698,7 @@ Status ForEachSubshapeHelper(const Shape& shape,
 // Helper for ForEachMutableSubshape which visits the subshapes of the given
 // shape in DFS pre-order starting with the index.
 Status ForEachMutableSubshapeHelper(
-    Shape* shape, const ShapeUtil::MutatingVisitorFunction func,
+    Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func,
     ShapeIndex* index) {
   TF_RETURN_IF_ERROR(func(shape, *index));
   if (ShapeUtil::IsTuple(*shape)) {
@@ -705,14 +714,40 @@ Status ForEachMutableSubshapeHelper(
 }  // namespace

-/* static */ Status ShapeUtil::ForEachSubshape(const Shape& shape,
-                                               VisitorFunction func) {
+/* static */ void ShapeUtil::ForEachSubshape(const Shape& shape,
+                                             const VisitorFunction& func) {
+  ShapeIndex index;
+  ForEachSubshapeHelper(
+      shape,
+      [&func](const Shape& subshape, const ShapeIndex& index) {
+        func(subshape, index);
+        return Status::OK();
+      },
+      &index)
+      .IgnoreError();
+}
+
+/* static */ void ShapeUtil::ForEachMutableSubshape(
+    Shape* shape, const MutatingVisitorFunction& func) {
+  ShapeIndex index;
+  ForEachMutableSubshapeHelper(
+      shape,
+      [&func](Shape* subshape, const ShapeIndex& index) {
+        func(subshape, index);
+        return Status::OK();
+      },
+      &index)
+      .IgnoreError();
+}
+
+/* static */ Status ShapeUtil::ForEachSubshapeWithStatus(
+    const Shape& shape, const StatusVisitorFunction& func) {
   ShapeIndex index;
   return ForEachSubshapeHelper(shape, func, &index);
 }

-/* static */ Status ShapeUtil::ForEachMutableSubshape(
-    Shape* shape, MutatingVisitorFunction func) {
+/* static */ Status ShapeUtil::ForEachMutableSubshapeWithStatus(
+    Shape* shape, const MutatingStatusVisitorFunction& func) {
   ShapeIndex index;
   return ForEachMutableSubshapeHelper(shape, func, &index);
 }
@@ -725,9 +760,17 @@ Status ForEachMutableSubshapeHelper(
     new_shape.add_dimensions(dim);
   }
   if (shape.has_layout()) {
-    new_shape.mutable_layout()->clear_minor_to_major();
+    Layout* new_layout = new_shape.mutable_layout();
+    new_layout->clear_minor_to_major();
     for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
-      new_shape.mutable_layout()->add_minor_to_major(index);
+      new_layout->add_minor_to_major(index);
+    }
+    if (shape.layout().padded_dimensions_size() > 0) {
+      new_layout->clear_padded_dimensions();
+      for (auto dim :
+           Permute(permutation, shape.layout().padded_dimensions())) {
+        new_layout->add_padded_dimensions(dim);
+      }
     }
   }
   return new_shape;
@@ -744,27 +787,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
   // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
   // the degenerate input/output dimensions in the gap to
   // deleted_indices/inserted_indices respectively.
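   // For example (editor's illustration): reshaping f32[2,1,5] to f32[2,5,1]
   // deletes the size-1 input dimension 1 and inserts the size-1 output
   // dimension 2.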
- auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices, - &inserted_indices]( - std::pair prior_unmodified_dim_pair, - std::pair unmodified_dim_pair) { - for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; - modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) { - if (shape_pre.dimensions(modified_input_dim) > 1) { - return false; - } - deleted_indices.push_back(modified_input_dim); - } - for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; - modified_output_dim < unmodified_dim_pair.second; - ++modified_output_dim) { - if (shape_post.dimensions(modified_output_dim) > 1) { - return false; - } - inserted_indices.push_back(modified_output_dim); - } - return true; - }; + auto check_modified_dims = + [&shape_pre, &shape_post, &deleted_indices, &inserted_indices]( + std::pair prior_unmodified_dim_pair, + std::pair unmodified_dim_pair) { + for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; + modified_input_dim < unmodified_dim_pair.first; + ++modified_input_dim) { + if (shape_pre.dimensions(modified_input_dim) > 1) { + return false; + } + deleted_indices.push_back(modified_input_dim); + } + for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; + modified_output_dim < unmodified_dim_pair.second; + ++modified_output_dim) { + if (shape_post.dimensions(modified_output_dim) > 1) { + return false; + } + inserted_indices.push_back(modified_output_dim); + } + return true; + }; std::vector> unmodified_dims = DimensionsUnmodifiedByReshape(shape_pre, shape_post); @@ -780,8 +824,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? unmodified_dims[i] - : std::make_pair(ShapeUtil::Rank(shape_pre), - ShapeUtil::Rank(shape_post)); + : std::make_pair(Rank(shape_pre), Rank(shape_post)); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -856,9 +899,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - CHECK_EQ(ShapeUtil::ElementsIn(input_shape), - ShapeUtil::ElementsIn(output_shape)); - if (ShapeUtil::ElementsIn(input_shape) == 0) { + CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape)); + if (ElementsIn(input_shape) == 0) { return true; } @@ -972,21 +1014,17 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. 
@@ -972,21 +1014,17 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. - Shape input_shape_dim0_major = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); - Shape output_shape_dim0_major = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - output_shape.element_type(), - AsInt64Slice(output_shape.dimensions())); + Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); + Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { + for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector<int64> input_unit_index(ShapeUtil::Rank(input_shape), 0); + std::vector<int64> input_unit_index(Rank(input_shape), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1010,6 +1048,140 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } +/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts( + const Shape& input_shape, const Shape& output_shape) { + int64 input_rank = Rank(input_shape); + int64 output_rank = Rank(output_shape); + + // First, calculate an alignment of the dimensions. A consecutive sequence of + // input dimensions and output dimensions belong to the same alignment part if + // the products of their dimension bounds are the same. In the easiest case, + // an alignment part consists of one input dimension and one output dimension + // which both have the same dimension bound. An alignment part specifies which + // dimensions need to be kept together in a physical layout if we want a + // reshape to be a bitcast. The order of the alignment parts is defined by the + // physical layout of the input shape, so when we construct the layout for the + // output shape we just process the alignment parts in this order, and then + // lay out the dimensions belonging to each part in descending (major to minor) + // order. + + // Stores the input and output dimension numbers where each alignment part + // starts. + std::vector<std::pair<int64, int64>> alignment; + alignment.push_back({0, 0}); + + // Stores a mapping from the input dimension to the alignment part it belongs + // to. + std::vector<int64> dimension_to_alignment_index(input_rank); + int64 input_dimension_product = 1, output_dimension_product = 1; + for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) { + // Check if we have reached the end of an alignment part. + if (input_dimension_product == output_dimension_product && + input_dimension_product > 1) { + alignment.push_back({i, j}); + input_dimension_product = output_dimension_product = 1; + } + if (input_dimension_product < output_dimension_product || + j == output_rank) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + dimension_to_alignment_index[i] = alignment.size() - 1; + input_dimension_product *= input_shape.dimensions(i); + ++i; + } else { + output_dimension_product *= output_shape.dimensions(j); + ++j; + } + } + if (input_dimension_product != output_dimension_product) { + return tensorflow::gtl::nullopt; + } + // We also need to store an end element so that we know where the last + // alignment part ends.
+ alignment.push_back({input_rank, output_rank}); + + // Now check if the physical layout can potentially be aligned to the output + // shape by changing the physical layout of the output shape. We need to check + // that all dimension numbers that belong to the same alignment part appear + // consecutively, and are in descending order. However, we can ignore any + // trivial dimension bounds of 1, because they can be placed anywhere. + auto input_dimension_numbers = input_shape.layout().minor_to_major(); + std::vector<int64> output_layout; + output_layout.reserve(output_rank); + for (int64 i = 0; i < input_rank;) { + int64 current_dimension_number = input_dimension_numbers[i]; + + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(current_dimension_number) == 1) { + ++i; + continue; + } + + // Calculate the number of non-trivial dimension bounds in the input shape + // belonging to the current alignment part. + const int64 current_alignment_index = + dimension_to_alignment_index[current_dimension_number]; + // Because of the special end element that we added, we can be sure that + // 'current_alignment_index' is < alignment.size() - 1. + CHECK_LT(current_alignment_index, alignment.size() - 1); + int64 num_non_trivial_dimensions_in_alignment_part = 0; + for (int64 j = alignment[current_alignment_index].first; + j < alignment[current_alignment_index + 1].first; ++j) { + if (input_shape.dimensions(j) != 1) { + ++num_non_trivial_dimensions_in_alignment_part; + } + } + + // Check that the following 'num_non_trivial_dimensions_in_alignment_part' + // dimension numbers (ignoring dimension numbers with dimension bound 1) are + // in descending order and belong to the current alignment part. + for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + ++i, ++j) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { + --j; + continue; + } + // If the current dimension number belongs to a different alignment part, + // or the dimension numbers are not in descending order, we can return + // early. + if (dimension_to_alignment_index[input_dimension_numbers[i]] != + current_alignment_index || + input_dimension_numbers[i] > current_dimension_number) { + return tensorflow::gtl::nullopt; + } + current_dimension_number = input_dimension_numbers[i]; + } + + // The output dimension numbers that belong to the current alignment part + // need to appear in the same descending order as in the input. Again, we + // can skip dimensions with a bound of 1. + for (int64 j = alignment[current_alignment_index + 1].second - 1; + j >= alignment[current_alignment_index].second; --j) { + if (output_shape.dimensions(j) != 1) { + output_layout.push_back(j); + } + } + } + // Now add all the dimensions with dimension bound 1 at the end of + // 'output_layout'.
+ for (int64 i = 0; i < output_rank; ++i) { + if (output_shape.dimensions(i) == 1) { + output_layout.push_back(i); + } + } + CHECK_EQ(output_layout.size(), output_rank); + Shape output_shape_with_layout = MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + output_layout); + CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); + return output_shape_with_layout; +} + /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); @@ -1044,4 +1216,34 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +/* static */ void ShapeUtil::ForEachIndex( + const Shape& shape, tensorflow::gtl::ArraySlice<int64> base, + tensorflow::gtl::ArraySlice<int64> count, + tensorflow::gtl::ArraySlice<int64> incr, + const IndexVisitorFunction& visitor_function) { + if (ShapeUtil::HasZeroElements(shape)) { + return; + } + DCHECK_EQ(Rank(shape), base.size()); + DCHECK_EQ(incr.size(), base.size()); + DCHECK_EQ(count.size(), base.size()); + const Layout& layout = shape.layout(); + int64 rank = layout.minor_to_major_size(); + // Allows handling R0 arrays, such that the visitor function will be called + // once with the proper empty indexes. + int64 n = -1; + std::vector<int64> indexes(base.begin(), base.end()); + while (n < rank && visitor_function(indexes)) { + // Increments dimensions in minor to major order. + for (n = 0; n < rank; ++n) { + int64 dim = layout.minor_to_major(n); + indexes[dim] += incr[dim]; + if (indexes[dim] < base[dim] + count[dim]) { + break; + } + indexes[dim] = base[dim]; + } + } +} + } // namespace xla
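Before the header changes below, a small sketch of how the two new entry points compose. The row-major 3x8 input and the flat 24-element output are hypothetical choices, not taken from the patch:

```c++
// AlignLayouts picks a layout for the reshape target that makes the
// reshape a bitcast, keeping the input layout fixed.
Shape input = ShapeUtil::MakeShapeWithLayout(F32, {3, 8}, {1, 0});
auto aligned = ShapeUtil::AlignLayouts(input, ShapeUtil::MakeShape(F32, {24}));
if (aligned) {
  CHECK(ShapeUtil::ReshapeIsBitcast(input, aligned.value()));

  // ForEachIndex then walks all 24 element indices in minor-to-major order;
  // returning true from the visitor continues the iteration.
  int64 visited = 0;
  ShapeUtil::ForEachIndex(aligned.value(), /*base=*/{0}, /*count=*/{24},
                          /*incr=*/{1},
                          [&visited](const std::vector<int64>& index) {
                            ++visited;
                            return true;
                          });
  CHECK_EQ(visited, 24);
}
```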
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 290d993b50a..853be6b4cb8 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -72,11 +73,18 @@ class ShapeIndex { return indices_ == other.indices_; } bool operator!=(const ShapeIndex& other) const { return !(*this == other); } + bool operator<(const ShapeIndex& other) const { + return indices_ < other.indices_; + } + + string ToString() const; private: std::vector<int64> indices_; }; +std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); + // Namespaced collection of (static) shape utilities. // // These are all effectively convenience functions for testing/tweaking proto @@ -220,7 +228,7 @@ class ShapeUtil { // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); - // Validates the the provided shape satisfies invariants, except those that + // Validates that the provided shape satisfies invariants, except those that // pertain to layout. // // Layout is optional for client-provided shapes, so that the compiler may @@ -287,18 +295,31 @@ class ShapeUtil { static const Shape& GetSubshape(const Shape& shape, const ShapeIndex& index); static Shape* GetMutableSubshape(Shape* shape, const ShapeIndex& index); - // Calls the given visitor function for each subshape of the given shape. - // Returns early if an error status is returned. Subshapes are visited in DFS - // pre-order starting with the entire shape (index {}). - using VisitorFunction = std::function<Status(const Shape& /*subshape*/, const ShapeIndex& /*index*/)>; - static Status ForEachSubshape(const Shape& shape, VisitorFunction func); + // Returns whether the given index in the given shape is a leaf element of the + // shape. + static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); - // Mutating variant of ForEachSubshape. + // Calls the given visitor function for each subshape of the given shape. + // Subshapes are visited in DFS pre-order starting with the entire shape + // (index {}). + using VisitorFunction = std::function<void(const Shape& /*subshape*/, const ShapeIndex& /*index*/)>; + static void ForEachSubshape(const Shape& shape, const VisitorFunction& func); using MutatingVisitorFunction = + std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>; + static void ForEachMutableSubshape(Shape* shape, + const MutatingVisitorFunction& func); + + // Variants of ForEach(Mutable)Subshape which propagate Status from the + // visitor function. + using StatusVisitorFunction = std::function<Status(const Shape& /*subshape*/, const ShapeIndex& /*index*/)>; + static Status ForEachSubshapeWithStatus(const Shape& shape, + const StatusVisitorFunction& func); + using MutatingStatusVisitorFunction = std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>; - static Status ForEachMutableSubshape(Shape* shape, - MutatingVisitorFunction func); + static Status ForEachMutableSubshapeWithStatus( + Shape* shape, const MutatingStatusVisitorFunction& func); // Removes all degenerate dimensions (size one) from the given shape. The // stripped minor_to_major preserves the relative ordering of non-degenerate @@ -370,6 +391,15 @@ class ShapeUtil { static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); + // Finds a physical layout for 'output_shape' such that + // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns + // true (where 'output_shape_with_layout' is 'output_shape' with the found + // layout). The layout of 'input_shape' is kept fixed. Returns + // 'output_shape_with_layout' if such a layout can be found, and + // tensorflow::gtl::nullopt otherwise. + static tensorflow::gtl::optional<Shape> AlignLayouts( + const Shape& input_shape, const Shape& output_shape); + // Returns a shape with the given dimension deleted. // For example: // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` @@ -383,6 +413,19 @@ class ShapeUtil { static Shape FilterDimensions(const std::function<bool(int64)>& p, Shape shape); + // Iterates through all the shape indexes, in minor to major order, starting + // from the base indexes, incrementing by the incr steps, up to count + // (index[i] < base[i] + count[i]), and calls the visitor_function with the + // current index. + // The visitor function should return true if it wants to continue, or false + // otherwise. + using IndexVisitorFunction = std::function<bool(const std::vector<int64>&)>; + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice<int64> base, + tensorflow::gtl::ArraySlice<int64> count, + tensorflow::gtl::ArraySlice<int64> incr, + const IndexVisitorFunction& visitor_function); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public methods. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 9e6b243611b..69ef6175ccd 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,14 +16,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { +using ::testing::ElementsAre; + TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); @@ -319,6 +322,30 @@ TEST(ShapeUtilTest, GetSubshape) { ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0}))); } +TEST(ShapeUtilTest, IsLeafIndex) { + // Test array shape. + Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123}); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(array_shape, {})); + + // Test tuple shape. + Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape}); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(tuple_shape, {})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {0})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {1})); + + // Test nested tuple shape. + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape( + {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({array_shape, array_shape}), + array_shape})}); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {0})); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 0})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); +} + TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); @@ -377,13 +404,12 @@ TEST(ShapeUtilTest, HumanString) { TEST(ShapeUtilTest, ForEachSubshapeArray) { const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { EXPECT_EQ(&shape, &subshape); EXPECT_TRUE(index.empty()); ++calls; - return tensorflow::Status::OK(); - })); + }); EXPECT_EQ(1, calls); } @@ -393,7 +419,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), ShapeUtil::MakeShape(PRED, {33})})}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { EXPECT_TRUE( ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index))); @@ -405,8 +431,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape)); } ++calls; - return tensorflow::Status::OK(); - })); + }); EXPECT_EQ(5, calls); } @@ -416,7 +441,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), ShapeUtil::MakeShape(PRED, {33})})}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape( + ShapeUtil::ForEachMutableSubshape( &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) { // Pointer values should be equal EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index)); @@ -428,8 +453,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape)); } ++calls; - return 
tensorflow::Status::OK(); - })); + }); EXPECT_EQ(5, calls); } @@ -443,24 +467,52 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } +TEST(ShapeUtilTest, ForEachIndex) { + struct ShapeDimensionAndNumberInvocations { + std::vector<int64> dimensions; + int invocations; + } test_data[] = { + {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0}, + {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610}, + }; + + for (const auto& data : test_data) { + Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); + // Increments at every invocation. + int invocations = 0; + auto increment_func = [&invocations](const std::vector<int64>& indexes) { + invocations++; + return true; + }; + + std::vector<int64> zero_base(data.dimensions.size(), 0); + std::vector<int64> step(data.dimensions.size(), 1); + + ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step, + increment_func); + + EXPECT_EQ(invocations, data.invocations); + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. - EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) { // All input dimensions should be unmodified. One of the output dimensions is // modified because the output rank is larger by one. - EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { @@ -468,11 +520,10 @@ TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { // 4, 1, 3, 5, 6, 7 // | // 2, 6, 1, 5, 1, 42 - EXPECT_TRUE( - ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), - ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), - std::vector<std::pair<int64, int64>>({{3, 3}}))); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), + ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), + ElementsAre(std::make_pair(3, 3))); } TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) { @@ -521,5 +572,58 @@ TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(4, 3, 2, 1, 0, 5)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); + + aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35,
11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(3, 2, 1, 0, 4)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1}, + {5, 0, 4, 2, 1, 3, 6, 7, 9, 8}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +// A test case where the consecutive elements of the input shape belonging to +// the same layout part are not in descending order. +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) { + // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except + // that the first two dimension numbers are exchanged. + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {2, 3, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_FALSE(aligned_shape); +} + +// A test case where the physical layout of the input shape does not place all +// dimensions that belong to the same alignment part consecutively. +TEST(AlignmentTest, + AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77})); + EXPECT_FALSE(aligned_shape); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index f3b561fada3..4eb3bf37664 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,25 +21,7 @@ limitations under the License. namespace xla { -#if defined(__clang__) -// Only clang supports warn_unused_result as a type annotation. -class TF_MUST_USE_RESULT Status; -#endif - -// Simple wrapper around tensorflow::Status that has the MUST_USE_RESULT -// annotation above. When tensorflow::Status adopts this annotation, this can -// simply become a "using tensorflow::Status". -class Status : public tensorflow::Status { - public: - static Status OK() { return tensorflow::Status::OK(); } - - // Note: implicit constructor. - Status(tensorflow::Status other) : tensorflow::Status(other) {} - - Status() : tensorflow::Status() {} - Status(tensorflow::error::Code code, tensorflow::StringPiece msg) - : tensorflow::Status(code, msg) {} -}; +using tensorflow::Status; } // namespace xla
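Since xla::Status is now a plain alias, code that previously leaned on the implicit conversion between the two Status types keeps compiling unchanged. A tiny sketch, where DoWork is a hypothetical helper:

```c++
tensorflow::Status DoWork();  // Hypothetical helper returning a TF status.

xla::Status CallDoWork() {
  // No wrapper hop between xla::Status and tensorflow::Status remains;
  // they are now the same type.
  TF_RETURN_IF_ERROR(DoWork());
  return tensorflow::Status::OK();
}
```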
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/test.h" namespace xla { @@ -40,15 +40,15 @@ Status RetCheckSuccess() { TEST(StatusMacros, RetCheckFailing) { Status status = RetCheckFail(); EXPECT_EQ(status.code(), tensorflow::error::INTERNAL); - EXPECT_MATCH(status.error_message(), - xla::testing::ContainsRegex("RET_CHECK failure.*2 > 3")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("RET_CHECK failure.*2 > 3")); } TEST(StatusMacros, RetCheckFailingWithExtraMessage) { Status status = RetCheckFailWithExtraMessage(); EXPECT_EQ(status.code(), tensorflow::error::INTERNAL); - EXPECT_MATCH(status.error_message(), - xla::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message")); } TEST(StatusMacros, RetCheckSucceeding) { @@ -73,7 +73,7 @@ Status ReturnStatusError() { return (tensorflow::errors::Internal("foobar")); } using StatusReturningFunction = std::function; -StatusOr CallStatusReturningFunction(StatusReturningFunction func) { +StatusOr CallStatusReturningFunction(const StatusReturningFunction& func) { TF_RETURN_IF_ERROR(func()); return 42; } diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 8046a2216fe..d8cd736238c 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -105,7 +105,6 @@ class StatusOr { // In optimized builds, passing Status::OK here will have the effect // of passing tensorflow::error::INTERNAL as a fallback. StatusOr(Status status); // NOLINT - StatusOr(tensorflow::Status status); // NOLINT // Construct a new StatusOr with the given value. If T is a plain pointer, // value must not be NULL. After calling this constructor, calls to @@ -196,8 +195,6 @@ class StatusOr : public StatusOr { : StatusOr::StatusOr(std::move(value)) {} StatusOr(Status status) // NOLINT : StatusOr::StatusOr(std::move(status)) {} - StatusOr(tensorflow::Status status) // NOLINT - : StatusOr::StatusOr(std::move(status)) {} template StatusOr(StatusOr&& other) // NOLINT : StatusOr::StatusOr(std::move(other)) {} @@ -245,14 +242,6 @@ inline StatusOr::StatusOr(Status status) } } -template -inline StatusOr::StatusOr(tensorflow::Status status) - : status_(status) { - if (status_.ok()) { - status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg(); - } -} - template inline StatusOr::StatusOr(const T& value) : value_(value) { diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index d98eb279336..f8555113f81 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -20,10 +20,10 @@ limitations under the License. 
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index d98eb279336..f8555113f81 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include <memory> #include <type_traits> +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" namespace xla { @@ -45,7 +45,7 @@ class Base2 { class Derived : public Base1, public Base2 { public: - virtual ~Derived() {} + ~Derived() override {} int evenmorepad; }; @@ -436,17 +436,17 @@ class BenchmarkFactory { } Status ArgumentFactoryFail(T** result) TF_ATTRIBUTE_NOINLINE { - *result = NULL; + *result = nullptr; return Status(tensorflow::error::CANCELLED, ""); } Status ArgumentFactoryFailShortMsg(T** result) TF_ATTRIBUTE_NOINLINE { - *result = NULL; + *result = nullptr; return Status(::tensorflow::error::INTERNAL, ""); } Status ArgumentFactoryFailLongMsg(T** result) TF_ATTRIBUTE_NOINLINE { - *result = NULL; + *result = nullptr; return Status(::tensorflow::error::INTERNAL, "a big string of message junk that will never be read"); } @@ -489,26 +489,30 @@ class BenchmarkType { // Calibrate the amount of time spent just calling DoWork, since each of our // tests will do this, we can subtract this out of benchmark results. -static void BM_CalibrateWorkLoop(int iters) { +void BM_CalibrateWorkLoop(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; BenchmarkType* result = factory.TrivialFactory(); tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { - if (result != NULL) result->DoWork(); + if (result != nullptr) { + result->DoWork(); + } } } BENCHMARK(BM_CalibrateWorkLoop); // Measure the time taken to call into the factory, return the value, // determine that it is OK, and invoke a trivial function. -static void BM_TrivialFactory(int iters) { +void BM_TrivialFactory(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { BenchmarkType* result = factory.TrivialFactory(); - if (result != NULL) result->DoWork(); + if (result != nullptr) { + result->DoWork(); + } } } BENCHMARK(BM_TrivialFactory); @@ -516,14 +520,14 @@ BENCHMARK(BM_TrivialFactory); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -static void BM_ArgumentFactory(int iters) { +void BM_ArgumentFactory(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { - BenchmarkType* result = NULL; + BenchmarkType* result = nullptr; Status status = factory.ArgumentFactory(&result); - if (status.ok() && result != NULL) { + if (status.ok() && result != nullptr) { result->DoWork(); } } @@ -532,7 +536,7 @@ BENCHMARK(BM_ArgumentFactory); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -static void BM_StatusOrFactory(int iters) { +void BM_StatusOrFactory(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); @@ -548,14 +552,14 @@ BENCHMARK(BM_StatusOrFactory); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function.
-static void BM_ArgumentFactoryFail(int iters) { +void BM_ArgumentFactoryFail(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { - BenchmarkType* result = NULL; + BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFail(&result); - if (status.ok() && result != NULL) { + if (status.ok() && result != nullptr) { result->DoWork(); } } @@ -564,7 +568,7 @@ BENCHMARK(BM_ArgumentFactoryFail); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -static void BM_StatusOrFactoryFail(int iters) { +void BM_StatusOrFactoryFail(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); @@ -580,14 +584,14 @@ BENCHMARK(BM_StatusOrFactoryFail); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -static void BM_ArgumentFactoryFailShortMsg(int iters) { +void BM_ArgumentFactoryFailShortMsg(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { - BenchmarkType* result = NULL; + BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFailShortMsg(&result); - if (status.ok() && result != NULL) { + if (status.ok() && result != nullptr) { result->DoWork(); } } @@ -596,7 +600,7 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -static void BM_StatusOrFactoryFailShortMsg(int iters) { +void BM_StatusOrFactoryFailShortMsg(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); @@ -612,14 +616,14 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -static void BM_ArgumentFactoryFailLongMsg(int iters) { +void BM_ArgumentFactoryFailLongMsg(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); for (int i = 0; i != iters; ++i) { - BenchmarkType* result = NULL; + BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFailLongMsg(&result); - if (status.ok() && result != NULL) { + if (status.ok() && result != nullptr) { result->DoWork(); } } @@ -628,7 +632,7 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -static void BM_StatusOrFactoryFailLongMsg(int iters) { +void BM_StatusOrFactoryFailLongMsg(int iters) { tensorflow::testing::StopTiming(); BenchmarkFactory<BenchmarkType> factory; tensorflow::testing::StartTiming(); diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h new file mode 100644 index 00000000000..87a8c5f3a52 --- /dev/null +++ b/tensorflow/compiler/xla/test.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_TEST_H_ + +// This header includes gmock.h and enables the use of gmock matchers in tests +// in third_party/tensorflow/compiler/xla. +// +// Tests including this header can use the macros EXPECT_THAT(...) and +// ASSERT_THAT(...) in combination with gmock matchers. +// Example: +// std::vector<int> vec = Foo(); +// EXPECT_THAT(vec, ::testing::ElementsAre(1, 2, 3)); +// +// For more details on gmock matchers see: +// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers +// +// The advantages of using gmock matchers instead of self-defined matchers are +// better error messages, more maintainable tests, and more test coverage. +// +// Note that while the use of gmock matchers is allowed in the xla project, the +// use of mocks is disallowed in the whole tensorflow project! + +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include "testing/base/public/gmock.h" +#else +#include <gmock/gmock-generated-matchers.h> +#include <gmock/gmock-matchers.h> +#endif + +#include "tensorflow/core/platform/test.h" + +#endif  // TENSORFLOW_COMPILER_XLA_TEST_H_
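The new header is what lets the tests in this patch replace EXPECT_MATCH with standard gmock assertions. A short sketch in the style of the header's own example comment; the test name and values are made up:

```c++
#include "tensorflow/compiler/xla/test.h"

TEST(GmockMigrationExampleTest, MatchersWork) {
  std::vector<int> vec = {1, 2, 3};
  EXPECT_THAT(vec, ::testing::ElementsAre(1, 2, 3));
  EXPECT_THAT("RET_CHECK failure: 2 > 3",
              ::testing::ContainsRegex("2 > 3"));
}
```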
diff --git a/tensorflow/compiler/xla/test_helpers.cc b/tensorflow/compiler/xla/test_helpers.cc deleted file mode 100644 index 02abfdeab80..00000000000 --- a/tensorflow/compiler/xla/test_helpers.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/regexp.h" - -namespace xla { -namespace testing { - -AssertionResult::AssertionResult(const AssertionResult& other) - : success_(other.success_), - message_(other.message_ != nullptr ? new std::string(*other.message_) - : static_cast<std::string*>(nullptr)) {} - -// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. -AssertionResult AssertionResult::operator!() const { - AssertionResult negation(!success_); - if (message_ != nullptr) negation << *message_; - return negation; -} - -AssertionResult& AssertionResult::operator=(const AssertionResult& ar) { - success_ = ar.success_; - message_.reset(ar.message_ != nullptr ? new std::string(*ar.message_) - : nullptr); - return *this; -} - -AssertionResult AssertionFailure() { return AssertionResult(false); } - -AssertionResult AssertionSuccess() { return AssertionResult(true); } - -std::function<bool(tensorflow::StringPiece)> ContainsRegex( - const tensorflow::StringPiece regex) { - return [regex](const tensorflow::StringPiece to_test) { - if (RE2::PartialMatch( - tensorflow::RegexpStringPiece(to_test.data(), to_test.size()), - tensorflow::RegexpStringPiece(regex.data(), regex.size()))) { - return true; - } else { - LOG(ERROR) << "Expected to find " << regex << " in " << to_test; - return false; - } - }; -} - -std::function<bool(tensorflow::StringPiece)> HasSubstr( - const tensorflow::StringPiece part) { - return [part](const tensorflow::StringPiece whole) { - return whole.contains(part); - }; -} - -} // namespace testing -} // namespace xla diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index f923d9f36c8..634cdb5aa29 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -39,286 +39,6 @@ class Literal; namespace testing { -class AssertionResult { - public: - explicit AssertionResult(bool success) : success_(success) {} - - // Returns true iff the assertion succeeded. - operator bool() const { return success_; } // NOLINT - - // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. - AssertionResult operator!() const; - - // Returns the text streamed into this AssertionResult. Test assertions - // use it when they fail (i.e., the predicate's outcome doesn't match the - // assertion's expectation). When nothing has been streamed into the - // object, returns an empty string. - const char* message() const { - return message_ != nullptr ? message_->c_str() : ""; - } - - // Streams a custom failure message into this object. - template <typename T> - AssertionResult& operator<<(const T& value) { - AppendMessage(::testing::Message() << value); - return *this; - } - - // Allows streaming basic output manipulators such as endl or flush into - // this object. - AssertionResult& operator<<( - std::ostream& (*basic_manipulator)(std::ostream& stream)) { - AppendMessage(::testing::Message() << basic_manipulator); - return *this; - } - - // Copy operator. - AssertionResult(const AssertionResult& ar); - - // Assignment operator. - AssertionResult& operator=(const AssertionResult&); - - private: - // Appends the contents of message to message_. - void AppendMessage(const ::testing::Message& a_message) { - if (message_ == nullptr) message_.reset(new std::string); - message_->append(a_message.GetString().c_str()); - } - - bool success_ = false; - - // Stores the message describing the condition in case the - // expectation construct is not satisfied with the predicate's - // outcome. Referenced via a pointer to avoid taking too much stack - // frame space with test assertions. - std::unique_ptr<std::string> message_; -}; - -AssertionResult AssertionFailure(); - -AssertionResult AssertionSuccess(); - -std::function<bool(tensorflow::StringPiece)> ContainsRegex( - const tensorflow::StringPiece regex); - -std::function<bool(tensorflow::StringPiece)> HasSubstr( - const tensorflow::StringPiece part); - -// Matcher for a vector of same-type values for which operator= is -// defined.
-template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> VectorMatcher( - const std::vector<T>& expected) { - return [expected](const std::vector<T>& actual) -> AssertionResult { - int len = expected.size(); - if (actual.size() != len) { - return AssertionFailure() << "Actual values len of " << actual.size() - << " != expected.size " << len; - } - for (int i = 0; i < len; ++i) { - if (actual[i] != expected[i]) { - return AssertionFailure() << "Element " << i << " actual " << actual[i] - << " != " << expected[i]; - } - } - return AssertionSuccess(); - }; -} - -// Approximate matcher for a vector of floats or similar. -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> -ApproxVectorMatcher(const std::vector<T>& expected, float abs_diff, - float rel_diff) { - return [abs_diff, rel_diff, - expected](const std::vector<T>& actual) -> AssertionResult { - int len = expected.size(); - if (actual.size() != len) { - AssertionResult ar = AssertionFailure() << "Actual values len of " - << actual.size() - << " != expected.size " << len; - LOG(ERROR) << ar.message(); - return ar; - } - for (int i = 0; i < len; ++i) { - T diff = actual[i] - expected[i]; - if (diff < 0) { - diff *= -1; - } - if (diff > abs_diff) { - T rdiff = (expected[i] != 0 ? diff / expected[i] : 0.0 * expected[i]); - if (rdiff > rel_diff) { - AssertionResult ar = AssertionFailure() - << "Element " << i << " actual " << actual[i] - << " != " << expected[i] - << "( abs_diff = " << diff - << ", rel_diff = " << rdiff << ")"; - LOG(ERROR) << ar.message(); - return ar; - } - } - } - return AssertionSuccess(); - }; -} - -// Matches a vector of same-type values against another, succeeding so -// long as they have the same length and every value in 'actual' -// matches one in 'expected.' Does not verify an exhaustive -// one-to-one mapping between the two. -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> -UnorderedElementsAre(const std::vector<T>& expected) { - return [expected](const std::vector<T>& actual) -> AssertionResult { - if (actual.size() != expected.size()) { - return AssertionFailure() << "sizes don't match"; - } - for (auto a : actual) { - bool found = false; - for (auto e : expected) { - if (a == e) { - found = true; - break; - } - } - if (!found) { - return AssertionFailure() << "actual element " << a - << " not in expected"; - } - } - return AssertionSuccess(); - }; -} - -// Overloaded cover functions for UnorderedElementsAre, for the numbers -// of values used in practice.
-template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a) { - std::vector<T> expected; - expected.push_back(a); - return testing::UnorderedElementsAre(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a, T b) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - return testing::UnorderedElementsAre(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a, T b, T c) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - return testing::UnorderedElementsAre(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a, T b, T c, T d) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - return testing::UnorderedElementsAre(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a, T b, T c, T d, T e) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - expected.push_back(e); - return testing::UnorderedElementsAre(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( - T a, T b, T c, T d, T e, T f) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - expected.push_back(e); - expected.push_back(f); - return testing::UnorderedElementsAre(expected); -} - -// Overloaded cover functions for VectorMatcher for the numbers of -// elements used in practice. -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( - T a) { - std::vector<T> expected; - expected.push_back(a); - return testing::VectorMatcher(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( - T a, T b) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - return testing::VectorMatcher(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( - T a, T b, T c) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - return testing::VectorMatcher(expected); -} - -template <typename T> -std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( - T a, T b, T c, T d) { - std::vector<T> expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - return testing::VectorMatcher(expected); -} - -// Convert a RepeatedField to a flat vector. -template <typename T> -std::vector<T> PBToVec(const tensorflow::protobuf::RepeatedField<T> rf) { - return std::vector<T>(rf.begin(), rf.end()); -} - -// Convert a List to a flat vector. -template <typename T> -std::vector<T> ListToVec(const std::list<T>& l) { - return std::vector<T>(l.begin(), l.end()); -} - -// Convert a Set to a flat vector. -template <typename T> -std::vector<T> SetToVec(const std::set<T>& c) { - return std::vector<T>(c.begin(), c.end()); -} - -// Convert an Array to a flat vector. -template <typename T> -std::vector<T> Array2DToVec(const Array2D<T>& a) { - return std::vector<T>(a.data(), a.data() + a.num_elements()); -} - namespace internal_status { inline const ::tensorflow::Status& GetStatus( const ::tensorflow::Status& status) { @@ -347,9 +67,4 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) { ASSERT_EQ(tensorflow::Status::OK(), \ xla::testing::internal_status::GetStatus(expression)) -// Macros that apply a Matcher to a Value, returning an -// AssertionResult which gets digested by a standard gunit macro.
-#define EXPECT_MATCH(V, M) EXPECT_TRUE((M)((V))) -#define ASSERT_MATCH(V, M) ASSERT_TRUE(M(V)) - #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 31d549cc421..13dd1a30b60 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -92,6 +93,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", @@ -101,8 +103,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -149,7 +151,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -173,7 +175,6 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -203,6 +204,7 @@ cc_library( "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", ], ) @@ -217,6 +219,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -236,6 +239,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", ], @@ -252,6 +256,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", ], @@ -271,6 +276,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -286,6 +292,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -309,6 +316,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -332,6 +340,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -346,7 +355,9 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", ], @@ -454,16 +465,18 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) @@ -656,7 +669,9 @@ xla_test( }, shard_count = 30, deps = [ + "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -886,6 +901,7 @@ xla_test( name = "copy_test", srcs = ["copy_test.cc"], deps = [ + ":client_library_test_base", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", @@ -930,7 +946,6 @@ xla_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) @@ -958,13 +973,13 @@ xla_test( 
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:test", ], ) @@ -1147,6 +1162,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1201,12 +1217,11 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1234,27 +1249,6 @@ xla_test( ], ) -xla_test( - name = "inprocess_service_test", - srcs = ["inprocess_service_test.cc"], - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - xla_test( name = "replay_test", srcs = ["replay_test.cc"], @@ -1344,6 +1338,22 @@ cc_test( ], ) +cc_test( + name = "hlo_metadata_test", + srcs = [ + "hlo_metadata_test.cc", + ], + deps = [ + ":local_client_test_base", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + xla_test( name = "round_trip_transfer_test", srcs = ["round_trip_transfer_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 23579088c9e..c07f2745fe9 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -27,14 +27,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -82,6 +85,50 @@ TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto result = builder.IsFinite(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +// A non-canonical quiet NaN value. +static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); + +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { + ComputationBuilder builder(client_, TestName()); + auto result = builder.IsFinite(builder.ConstantR0(NAN)); + ComputeAndCompareR0(&builder, false, {}); + + EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); + auto result_non_canonical = + builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + ComputeAndCompareR0(&builder, false, {}); + + const float inf = std::numeric_limits::infinity(); + auto result_inf = builder.IsFinite(builder.ConstantR0(inf)); + ComputeAndCompareR0(&builder, false, {}); + + auto result_neg_inf = builder.IsFinite(builder.ConstantR0(-inf)); + ComputeAndCompareR0(&builder, false, {}); + + auto result_zero = builder.IsFinite(builder.ConstantR0(0.0f)); + ComputeAndCompareR0(&builder, true, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { + ComputationBuilder builder(client_, TestName()); + const float inf = std::numeric_limits::infinity(); + EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); + auto a = builder.ConstantR1( + {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); + auto result = builder.IsFinite(a); + + ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, + {}); +} + TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); @@ -197,6 +244,150 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +TEST_F(ArrayElementwiseOpTest, DivS32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff, + -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101, + 7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX}; + // clang-format on + + std::vector dividends, divisors, quotients, remainders; + for (int32 divisor : vals) { + if (divisor != 0) { + for (int32 dividend : vals) { + // Avoid integer overflow. 
+        if (dividend != INT32_MIN || divisor != -1) {
+          dividends.push_back(dividend);
+          divisors.push_back(divisor);
+          quotients.push_back(dividend / divisor);
+          remainders.push_back(dividend % divisor);
+        }
+      }
+    }
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    ComputationDataHandle divisor;
+    auto dividend_data =
+        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
+    auto divisor_data =
+        CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
+    builder.Div(dividend, divisor);
+
+    ComputeAndCompareR1<int32>(&builder, quotients,
+                               {dividend_data.get(), divisor_data.get()});
+  }
+
+  // Test with a compile-time constant divisor.
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    auto dividend_data =
+        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
+    builder.Div(dividend, builder.ConstantR1<int32>(divisors));
+
+    ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()});
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    ComputationDataHandle divisor;
+    auto dividend_data =
+        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
+    auto divisor_data =
+        CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
+    builder.Rem(dividend, divisor);
+
+    ComputeAndCompareR1<int32>(&builder, remainders,
+                               {dividend_data.get(), divisor_data.get()});
+  }
+
+  // Test with a compile-time constant divisor.
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    auto dividend_data =
+        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
+    builder.Rem(dividend, builder.ConstantR1<int32>(divisors));
+
+    ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()});
+  }
+}
+
+TEST_F(ArrayElementwiseOpTest, DivU32s) {
+  // clang-format off
+  // Some interesting values to test.
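+  // (They cluster around 0, around the sign-bit boundary 0x80000000, and
+  // around UINT32_MAX, which are the edge cases for unsigned division.)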
+  std::vector<uint32> vals = {
+    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
+    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
+  // clang-format on
+
+  std::vector<uint32> dividends, divisors, quotients, remainders;
+  for (uint32 divisor : vals) {
+    if (divisor != 0) {
+      for (uint32 dividend : vals) {
+        dividends.push_back(dividend);
+        divisors.push_back(divisor);
+        quotients.push_back(dividend / divisor);
+        remainders.push_back(dividend % divisor);
+      }
+    }
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    ComputationDataHandle divisor;
+    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
+                                                   &builder, &dividend);
+    auto divisor_data =
+        CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
+    builder.Div(dividend, divisor);
+
+    ComputeAndCompareR1<uint32>(&builder, quotients,
+                                {dividend_data.get(), divisor_data.get()});
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
+                                                   &builder, &dividend);
+    builder.Div(dividend, builder.ConstantR1<uint32>(divisors));
+
+    ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()});
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    ComputationDataHandle divisor;
+    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
+                                                   &builder, &dividend);
+    auto divisor_data =
+        CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
+    builder.Rem(dividend, divisor);
+
+    ComputeAndCompareR1<uint32>(&builder, remainders,
+                                {dividend_data.get(), divisor_data.get()});
+  }
+
+  {
+    ComputationBuilder builder(client_, TestName());
+    ComputationDataHandle dividend;
+    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
+                                                   &builder, &dividend);
+    builder.Rem(dividend, builder.ConstantR1<uint32>(divisors));
+
+    ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()});
+  }
+}
+
 XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>(
@@ -441,6 +632,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
   ComputeAndCompareR1<bool>(&builder, {}, {});
 }

+TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
+  // Disable fast-math because we're operating on NaNs.
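+  // (Under fast-math the compiler may assume operands are never NaN and
+  // fold away the NaN comparisons; IEEE semantics require NaN != x to be
+  // true for every x, which is what this test checks.)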
+  SetFastMathDisabled(true);
+
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+  auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN});
+  auto compare = builder.Ne(lhs, rhs);
+
+  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
+}
+
 TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
   const int32 min = std::numeric_limits<int32>::min();
   const int32 max = std::numeric_limits<int32>::max();
@@ -575,12 +778,14 @@ TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
 TEST_F(ArrayElementwiseOpTest, PowF32s) {
   SetFastMathDisabled(true);
   ComputationBuilder builder(client_, TestName());
-  auto lhs = builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f});
-  auto rhs = builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN});
+  auto lhs =
+      builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
+  auto rhs =
+      builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
   auto minimum = builder.Pow(lhs, rhs);

-  ComputeAndCompareR1<float>(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {},
-                             error_spec_);
+  ComputeAndCompareR1<float>(
+      &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
 }

 XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
@@ -625,6 +830,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
   const int count = GetParam();
   ComputationBuilder builder(client_, TestName());
   std::vector<float> values;
+  values.reserve(count);
   for (int i = 0; i < count; ++i) {
     values.push_back(i / static_cast<float>(count));
   }
@@ -632,6 +838,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
   auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));

   std::vector<float> expected;
+  expected.reserve(values.size());
   for (float value : values) {
     expected.push_back(value * value);
   }
@@ -1584,7 +1791,7 @@ TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
 }

-TEST_F(ArrayElementwiseOpTest, R4_32x64x2x2_Plus_R1_64) {
+TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
   constexpr int d0 = 16;
   constexpr int d1 = 16;
   constexpr int d2 = 2;
@@ -1622,9 +1829,9 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
   auto concatenated = builder.Add(x, x);
   StatusOr<Computation> computation_status = builder.Build();
   ASSERT_FALSE(computation_status.ok());
-  EXPECT_MATCH(computation_status.status().ToString(),
-               testing::ContainsRegex(
-                   "Expected non-opaque argument for lhs of binary operation"));
+  EXPECT_THAT(computation_status.status().ToString(),
+              ::testing::ContainsRegex(
+                  "Expected non-opaque argument for lhs of binary operation"));
 }

 // Regression test for b/31927799.
"slice - y" is fused and requires implicit @@ -1638,7 +1845,7 @@ TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, y_literal->shape(), "y"); - auto slice = builder.Slice(x, {1}, {2}); + auto slice = builder.Slice(x, {1}, {2}, {1}); builder.Sub(slice, y); ComputeAndCompareR1(&builder, {-2, -3}, {x_data.get(), y_data.get()}, @@ -1654,7 +1861,9 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index adffac09e36..a1ca1de584f 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -74,6 +75,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index c7b533b80f1..ea58491038c 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -22,13 +22,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -45,8 +45,8 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); LOG(INFO) << "status received: " << computation.status(); - EXPECT_MATCH(computation.status().error_message(), - testing::HasSubstr("shape has invalid")); + EXPECT_THAT(computation.status().error_message(), + ::testing::HasSubstr("shape has invalid")); } TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { @@ -69,6 +69,7 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 598fd69909b..6a47f1b718a 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -194,6 +195,7 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index e825bd435b6..5e3b70702dd 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -141,6 +142,7 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 200d4d45634..25fe04a930e 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -22,18 +22,92 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { -using BroadcastSimpleTest = ClientLibraryTestBase; +class BroadcastSimpleTest : public ClientLibraryTestBase { + public: + ComputationDataHandle BuildBinOp(HloOpcode op, + const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs, + ComputationBuilder* builder) { + switch (op) { + case HloOpcode::kMinimum: { + return builder->Min(lhs, rhs); + } + case HloOpcode::kMaximum: { + return builder->Max(lhs, rhs); + } + case HloOpcode::kMultiply: { + return builder->Mul(lhs, rhs); + } + default: { + // Default to Add + return builder->Add(lhs, rhs); + } + } + } + + std::unique_ptr MakeR3Data( + tensorflow::gtl::ArraySlice bounds, + tensorflow::gtl::ArraySlice minor_to_major, Shape* r3_shape, + Array3D* r3_array, float start, float end, int seed) { + *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); + r3_array->FillRandom(start, end, seed); + auto r3_data = + LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(*r3_array), + LayoutUtil::MakeLayout(minor_to_major)); + std::unique_ptr r3_global_data = + client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + return r3_global_data; + } + + std::unique_ptr MakeR2Data( + tensorflow::gtl::ArraySlice bounds, + tensorflow::gtl::ArraySlice minor_to_major, Shape* r2_shape, + Array2D* r2_array, float start, float end, int seed) { + *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); + r2_array->FillRandom(start, end, seed); + auto r2_data = + LiteralUtil::Relayout(*LiteralUtil::CreateR2FromArray2D(*r2_array), + LayoutUtil::MakeLayout(minor_to_major)); + std::unique_ptr r2_global_data = + 
client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + return r2_global_data; + } + + float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) { + switch (op) { + case HloOpcode::kMinimum: { + return std::min(lhs, rhs); + } + case HloOpcode::kMaximum: { + return std::max(lhs, rhs); + } + case HloOpcode::kMultiply: { + return lhs * rhs; + } + case HloOpcode::kAdd: { + return lhs + rhs; + } + default: { + // Default to Add + CHECK(false); + } + } + } +}; + +using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { ComputationBuilder b(client_, TestName()); @@ -48,6 +122,19 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle src; + std::unique_ptr param_data = + CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", + /*builder=*/&b, /*data_handle=*/&src); + + b.Broadcast(src, {2, 3}); + Array2D expected(2, 3, 2.25); + ComputeAndCompareR2(&b, expected, {param_data.get()}, + ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { ComputationBuilder b(client_, TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 0}); @@ -76,6 +163,33 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +// Tests implicit broadcasting of PREDs. +XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { + ComputationBuilder b(client_, TestName()); + + Array2D x_vals(2, 1); + x_vals(0, 0) = true; + x_vals(1, 0) = false; + Array3D y_vals(2, 2, 1); + y_vals(0, 0, 0) = false; + y_vals(0, 1, 0) = false; + y_vals(1, 0, 0) = true; + y_vals(1, 1, 0) = true; + + ComputationDataHandle x, y; + auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); + auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); + b.LogicalAnd(x, y, /*broadcast_dimensions=*/{1, 2}); + + Array3D expected(2, 2, 1); + expected(0, 0, 0) = false; + expected(0, 1, 0) = false; + expected(1, 0, 0) = true; + expected(1, 1, 0) = false; + + ComputeAndCompareR3(&b, expected, {x_data.get(), y_data.get()}); +} + XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { ComputationBuilder b(client_, TestName()); b.Broadcast(b.ConstantR1({}), {2}); @@ -114,6 +228,434 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } +struct R3ImplicitBroadcastSpec { + std::array output_bounds; + std::array minor2major_layout; + std::array input_bounds; + HloOpcode op; +} kR3ImplicitBroadcastTestCases[] = { + {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd}, + {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd}, +}; + +class BroadcastR3ImplicitTest + : public BroadcastSimpleTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { + const R3ImplicitBroadcastSpec& spec = GetParam(); + ComputationBuilder builder(client_, TestName()); + + Shape r3_shape, 
r3_implicit_shape; + Array3D r3_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + Array3D r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], + spec.input_bounds[2]); + + std::unique_ptr r3_global_data = + MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape, + &r3_array, 1.0, 2.5, 56789); + std::unique_ptr r3_implicit_global_data = + MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, + &r3_implicit_array, 1.0, 0.2, 56789); + + auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); + auto r3_parameter = builder.Parameter(1, r3_shape, "input"); + ComputationDataHandle op = + BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + + Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], + indices[1] % spec.input_bounds[1], + indices[2] % spec.input_bounds[2]); + float r3 = r3_array(indices[0], indices[1], indices[2]); + *value = ApplyOpToFloats(spec.op, r3_implicit, r3); + }); + + int n1 = expected_array.n1(); + int n2 = expected_array.n2(); + int n3 = expected_array.n3(); + for (int64 i = 0; i < n1; i++) { + for (int64 j = 0; j < n2; j++) { + for (int64 k = 0; k < n3; k++) { + Each({i, j, k}, &expected_array(i, j, k)); + } + } + } + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + ComputeAndCompareLiteral( + &builder, *expected, + {r3_implicit_global_data.get(), r3_global_data.get()}, + ErrorSpec(1e-7, 1e-7)); +} + +INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, + BroadcastR3ImplicitTest, + ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); + +// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle r1h; + ComputationDataHandle r3h; + + Array3D r1d = {{{1}}, {{2}}}; + Array3D r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; + auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); + auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); + + b.Add(r3h, r1h); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 
2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +struct R2ImplicitBroadcastSpec { + std::array output_bounds; + std::array minor2major_layout; + std::array input_bounds1; + std::array input_bounds2; + HloOpcode op1; + HloOpcode op2; +} kR2ImplicitBroadcastTestCases[] = { + {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, + {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd}, + {{{2, 3}}, + {{1, 0}}, + {{2, 1}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kMinimum}, + {{{2, 3}}, + {{1, 0}}, + {{1, 3}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kMinimum}, + {{{2, 3}}, + {{1, 0}}, + {{1, 1}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kMinimum}, + {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd}, + {{{150, 150}}, + {{1, 0}}, + {{150, 1}}, + {{150, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{150, 150}}, + {{1, 0}}, + {{150, 1}}, + {{1, 150}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{150, 150}}, + {{1, 0}}, + {{150, 1}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{50, 150}}, + {{1, 0}}, + {{50, 1}}, + {{50, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{50, 150}}, + {{1, 0}}, + {{50, 1}}, + {{1, 150}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{50, 150}}, + {{1, 0}}, + {{50, 1}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{150, 50}}, + {{1, 0}}, + {{150, 1}}, + {{150, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{150, 50}}, + {{1, 0}}, + {{150, 1}}, + {{1, 50}}, + HloOpcode::kAdd, + HloOpcode::kAdd}, + {{{150, 50}}, + {{1, 0}}, + {{150, 1}}, + {{1, 1}}, + HloOpcode::kAdd, + HloOpcode::kAdd}}; + +class BroadcastR2ImplicitTest + : public BroadcastSimpleTest, + public ::testing::WithParamInterface {}; + +// Test r2 op1 r2_implicit_1 op2 r2_implicit_2 +// where R2 is a rank-2 operand, and r2_implicit_2 are two +// rank-2 operands with degenerate dimensions: +XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { + const R2ImplicitBroadcastSpec& spec = GetParam(); + + ComputationBuilder builder(client_, 
TestName()); + + // Operands with degenerate dimensions require implicit broadcasting: + Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2; + Array2D r2_array(spec.output_bounds[0], spec.output_bounds[1]); + Array2D r2_implicit_array1(spec.input_bounds1[0], + spec.input_bounds1[1]); + Array2D r2_implicit_array2(spec.input_bounds2[0], + spec.input_bounds2[1]); + + std::unique_ptr r2_global_data = + MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape, + &r2_array, 1.0, 2.5, 56789); + std::unique_ptr r2_implicit_global_data1 = + MakeR2Data(spec.input_bounds1, spec.minor2major_layout, + &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789); + std::unique_ptr r2_implicit_global_data2 = + MakeR2Data(spec.input_bounds2, spec.minor2major_layout, + &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); + + auto r2_implicit_parameter1 = + builder.Parameter(0, r2_implicit_shape1, "input0"); + auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); + auto r2_implicit_parameter2 = + builder.Parameter(2, r2_implicit_shape2, "input2"); + + ComputationDataHandle op1 = + BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); + ComputationDataHandle op2 = + BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + + Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); + + expected_array.Each([&](int64 i, int64 j, float* v) { + float v1 = r2_implicit_array1(i % spec.input_bounds1[0], + j % spec.input_bounds1[1]); + float v2 = r2_array(i, j); + float v3 = r2_implicit_array2(i % spec.input_bounds2[0], + j % spec.input_bounds2[1]); + float tmp = ApplyOpToFloats(spec.op1, v1, v2); + *v = ApplyOpToFloats(spec.op2, tmp, v3); + }); + + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + ComputeAndCompareLiteral( + &builder, *expected, + {r2_implicit_global_data1.get(), r2_global_data.get(), + r2_implicit_global_data2.get()}, + ErrorSpec(1e-6, 1e-6)); +} + +INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, + BroadcastR2ImplicitTest, + ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantR1({10, 20}); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1, {0}); + + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantR1({10, 20}); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r1, r3, {1}); + + auto 
expected = LiteralUtil::CreateR3( + {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantR1({10, 20}); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r1, r3, {2}); + + auto expected = LiteralUtil::CreateR3( + {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { + ComputationBuilder b(client_, TestName()); + auto r1_0 = b.ConstantR1({1000, 2000}); + auto r1_1 = b.ConstantR1({100, 200}); + auto r1_2 = b.ConstantR1({10, 20}); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + for (int i = 0; i < 3; ++i) { + r3 = b.Add(r1_0, r3, {0}); + r3 = b.Add(r3, r1_1, {1}); + r3 = b.Add(r1_2, r3, {2}); + } + r3 = b.Mul(r3, b.ConstantR0(-2)); + + auto expected = LiteralUtil::CreateR3( + {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, + {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { + ComputationBuilder b(client_, TestName()); + auto r1_0 = b.ConstantR1({1000, 2000}); + auto r1_1 = b.ConstantR1({100, 200}); + auto r1_2 = b.ConstantR1({10, 20}); + auto r0 = b.ConstantR0(3); + auto r3 = b.Broadcast(r0, {2, 2, 2}); + for (int i = 0; i < 3; ++i) { + r3 = b.Add(r1_0, r3, {0}); + r3 = b.Add(r3, r1_1, {1}); + r3 = b.Add(r1_2, r3, {2}); + } + r3 = b.Mul(r3, b.ConstantR0(-1)); + + auto expected = LiteralUtil::CreateR3( + {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, + {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) // results in a shape incompatible with the lhs [2, 3, 1]. 
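The three Invalid* hunks that follow pin down when XLA rejects a broadcast. As a quick reference, the two broadcast forms these tests exercise look like this in the `ComputationBuilder` API used above (a sketch distilled from the tests in this file, not new patch content; it assumes the `client_` member and headers already present here):

```c++
ComputationBuilder b(client_, "broadcast_forms");
auto r3 = b.ConstantLiteral(
    *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));

// In-dimension broadcasting: an explicit broadcast_dimensions argument maps
// the rank-1 operand onto dimension 0 of the rank-3 operand.
auto r1 = b.ConstantR1<float>({10, 20});
b.Add(r3, r1, /*broadcast_dimensions=*/{0});

// Degenerate-dimension (implicit) broadcasting: same-rank operands, with the
// size-1 dimensions of the smaller shape [1, 2, 1] stretched to match
// [2, 2, 2]. No broadcast_dimensions argument is needed.
auto degenerate =
    b.ConstantLiteral(*LiteralUtil::CreateR3<float>({{{1}, {2}}}));
b.Add(r3, degenerate);
```

The Invalid* tests check the failure side of both forms: mismatched non-degenerate dimensions must produce an error rather than silently stretching.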
@@ -126,8 +668,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH(result_status.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); } XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { @@ -139,9 +681,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("binary op BINOP_ADD with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -153,9 +694,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("binary op BINOP_ADD with incompatible shapes")); } } // namespace @@ -163,7 +703,9 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 1796a732e54..96a329a9bd8 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -43,7 +44,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -59,7 +60,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -82,7 +83,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. 
- auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -103,7 +104,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -122,7 +123,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -139,7 +140,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -159,7 +160,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -184,12 +185,12 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); Array4D expected(3, 3, 3, 1025); - Array2D yx(/*height=*/3, /*width=*/r1_size); + Array2D yx(3, r1_size); for (int64 y = 0; y < 3; ++y) { for (int64 x = 0; x < r1_size; ++x) { yx(y, x) = input_data[x]; @@ -215,7 +216,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -231,7 +232,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -254,7 +255,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. 
- auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -265,12 +266,44 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); } +TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { + auto builder = HloComputation::Builder(TestName()); + Array3D input_vals(2, 3, 4); + input_vals.FillRandom(1.0); + + Array4D expected(2, 3, 4, 5); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 4; ++k) { + for (int m = 0; m < 5; ++m) { + expected(i, j, k, m) = input_vals(i, j, k); + } + } + } + } + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR3FromArray3D(input_vals))); + + // Broadcast vector in dimensions 2 and 3. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); + + // Create HLO module, compile, and execute. + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 2c7eeb820d3..1f61743451a 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -102,7 +102,7 @@ def xla_test(name, elif backend == "cpu_parallel": backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - this_backend_args += ["--xla_cpu_parallel=true"] + this_backend_args += ["--xla_backend_extra_options=\"xla_cpu_parallel\""] elif backend == "gpu": backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 0b5e6d51277..55701c62db2 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -19,8 +19,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -117,6 +119,7 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 675c9fccb00..4825eaf19dc 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -19,18 +19,21 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { +using ::testing::ContainsRegex; + class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { @@ -60,15 +63,15 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(result_one_arg.status().error_message(), - testing::ContainsRegex("takes 2")); + ASSERT_THAT(result_one_arg.status().error_message(), + ContainsRegex("takes 2")); auto result_zero_args = client_->Execute(computation, {}); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(result_zero_args.status().error_message(), - testing::ContainsRegex("takes 2")); + ASSERT_THAT(result_zero_args.status().error_message(), + ContainsRegex("takes 2")); } XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { @@ -99,22 +102,22 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 0")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 0")); // Shape mismatch in parameter 1 (rank) status = 
client_->Execute(computation, {f32_data.get(), f32_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 1")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 1")); // Shape mismatch in parameter 1 (element type) status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 1")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 1")); } } // namespace @@ -122,6 +125,7 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7bf1168dc39..b96bb8f8469 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,14 +44,19 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { } } // namespace -ClientLibraryTestBase::ClientLibraryTestBase( - se::Platform* platform, - tensorflow::gtl::ArraySlice disabled_pass_names) +ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) : client_(GetOrCreateLocalClientOrDie(platform)) { - legacy_flags::HloPassPipelineFlags* flags = - legacy_flags::GetHloPassPipelineFlags(); - flags->xla_disable_hlo_passes = - tensorflow::str_util::Join(disabled_pass_names, ","); + *(execution_options_.mutable_debug_options()) = + legacy_flags::GetDebugOptionsFromFlags(); + + // Disabling constant_folding so that tests (usually written using Constants) + // will exercise the intended code paths, instead of being constant folded. + // + // TODO(b/38354253): Constant folding is currently disabled. Change tests to + // use Parameters instead of Constants, and re-enable constant folding by + // default. 
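+  // (Tests that do want folding can drop the entry again through
+  // mutable_debug_options()->clear_xla_disable_hlo_passes(); the clear_*
+  // accessor is generated for every repeated proto field.)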
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "constant_folding"); } string ClientLibraryTestBase::TestName() const { @@ -179,7 +184,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); - EXPECT_EQ(expected, actual->u8s()); + EXPECT_EQ(expected, actual->u8s_string()); } void ClientLibraryTestBase::ComputeAndCompareTuple( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 026f487c2df..f9e1082ebb4 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -46,14 +46,22 @@ namespace xla { class ClientLibraryTestBase : public ::testing::Test { protected: explicit ClientLibraryTestBase( - perftools::gputools::Platform* platform = nullptr, - tensorflow::gtl::ArraySlice disabled_pass_names = {}); + perftools::gputools::Platform* platform = nullptr); // Returns the name of the test currently being run. string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.set_disable_fast_math(disabled); + execution_options_.mutable_debug_options()->set_xla_enable_fast_math( + !disabled); + } + + void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } + + // Provides mutable access to the execution DebugOptions field; this lets + // tests tweak the options that will be used to compile/run the graph. + DebugOptions* mutable_debug_options() { + return execution_options_.mutable_debug_options(); } // TODO(b/25566808): Add helper that populates a literal from a testdata file. @@ -216,6 +224,16 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); + // Create a parameter instruction that wraps a given value and then stores + // into "data_handle" the global handle for that parameter. + // + // "parameter_number" is the parameter number. + // "name" is the name of the parameter instruction. + template + std::unique_ptr CreateR0Parameter( + NativeT value, int64 parameter_number, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Create a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // @@ -370,6 +388,17 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } +template +std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( + NativeT value, int64 parameter_number, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + std::unique_ptr literal = LiteralUtil::CreateR0(value); + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + return data; +} + template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 86ce636ee56..1247804dae0 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -113,6 +114,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 81c0568ff92..cc3eb0e8d46 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" @@ -43,33 +42,33 @@ void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, std::unique_ptr CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { - auto module_config = MakeUnique( - hlo_module->entry_computation()->ComputeProgramShape()); - module_config->set_fast_math_disabled(fast_math_disabled_); return backend_->compiler() - ->Compile(std::move(hlo_module), std::move(module_config), - test_hlo_dumper_, backend_->default_stream_executor()) + ->Compile(std::move(hlo_module), test_hlo_dumper_, + backend_->default_stream_executor()) .ConsumeValueOrDie(); } void CodegenTestBase::RunFileCheck(const string& input, const string& pattern) { + using tensorflow::io::JoinPath; + // Write input to a temporary file. char tempdir_template[] = "/tmp/ir_testXXXXXX"; char* tempdir_name = mkdtemp(tempdir_template); CHECK_NOTNULL(tempdir_name); - string pattern_path = - tensorflow::io::JoinPath(tempdir_name, "xla_hlo_test_ir_pattern"); + string pattern_path = JoinPath(tempdir_name, "xla_hlo_test_ir_pattern"); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), pattern_path, pattern)); // Invoke FileCheck to check whether input matches `pattern`. 
- tensorflow::SubProcess file_check_process; - const char* test_srcdir = getenv("TEST_SRCDIR"); - if (test_srcdir == nullptr) { - test_srcdir = "."; + const char* file_check_path_suffix = "external/llvm/FileCheck"; + string file_check_path; + if (const char* test_srcdir = getenv("TEST_SRCDIR")) { + file_check_path = JoinPath(test_srcdir, file_check_path_suffix); + } else { + file_check_path = file_check_path_suffix; } - string file_check_path = tensorflow::io::JoinPath( - test_srcdir, "external/llvm/FileCheck"); + + tensorflow::SubProcess file_check_process; file_check_process.SetProgram(file_check_path, {file_check_path, pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h index ba32aac8e4b..50c04531070 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.h +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -41,9 +41,6 @@ class CodegenTestBase : public HloTestBase { void CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern); - // Sets the fast-math-disabled flag on the config we use when compiling. - void set_fast_math_disabled(bool disabled) { fast_math_disabled_ = disabled; } - protected: // Compiles hlo_module to an executable, CHECK-failing if this fails. std::unique_ptr CompileToExecutable( @@ -52,8 +49,6 @@ class CodegenTestBase : public HloTestBase { // Runs FileCheck with the given pattern over the given string and EXPECTs // that FileCheck succeeded in matching the input. void RunFileCheck(const string& input, const string& pattern); - - bool fast_math_disabled_ = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 1d0df615824..18ea9714d1a 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -203,6 +204,7 @@ XLA_TEST_F(CompilationCacheTest, MutatedComputation) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 709ce5029c8..13c78fb1633 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,43 +17,75 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class ComputeConstantTest : public ClientLibraryTestBase { +// An enumerator for the client types that we want to iterate over in +// the various tests. +enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class ComputeConstantTest : public ::testing::Test { public: + explicit ComputeConstantTest( + perftools::gputools::Platform* platform = nullptr) + : platform_(platform) {} + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(::perftools::gputools::Platform* platform, + ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + StatusOr> ComputeConstantLiteral( - ComputationDataHandle operand, ComputationBuilder* builder, - Layout* output_layout = nullptr) { + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto remote_computed, builder->ComputeConstant(operand, output_layout)); - TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed)); + TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed)); return std::move(computed); } template - StatusOr ComputeConstantScalar(ComputationDataHandle operand, + StatusOr ComputeConstantScalar(Client* client, + const ComputationDataHandle& operand, ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder)); + TF_ASSIGN_OR_RETURN(auto literal, + ComputeConstantLiteral(client, operand, builder)); return LiteralUtil::Get(*literal, {}); } @@ -64,140 +96,162 @@ class ComputeConstantTest : public ClientLibraryTestBase { return result.ok() ? 
result.ValueOrDie() : false; } - template - void ExpectConstantComputedScalar(ComputationDataHandle operand, - Scalar expected, - ComputationBuilder* builder) { - Scalar computed = ComputeConstantScalar(operand, builder); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); - } + perftools::gputools::Platform* platform_; }; TEST_F(ComputeConstantTest, ScalarInt32Literal) { - ComputationBuilder b(client_, TestName()); - auto computation = b.ConstantR0(42); - EXPECT_TRUE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.ConstantR0(42); + EXPECT_TRUE(IsConstant(computation, &b)); - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 42); + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 42); + } } TEST_F(ComputeConstantTest, ScalarFloatAdd) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); - EXPECT_TRUE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + EXPECT_TRUE(IsConstant(computation, &b)); - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 44.0f); + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } } TEST_F(ComputeConstantTest, ScalarRng) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); - EXPECT_FALSE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), + ShapeUtil::MakeShape(F32, {})); + EXPECT_FALSE(IsConstant(computation, &b)); - auto value = ComputeConstantScalar(computation, &b); - ASSERT_FALSE(value.ok()) - << "computing a RNG value should not be considered a constant"; + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_FALSE(value.ok()) + << "computing a RNG value should not be considered a constant"; + } } TEST_F(ComputeConstantTest, DirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); - EXPECT_FALSE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + EXPECT_FALSE(IsConstant(computation, &b)); - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + auto value = ComputeConstantScalar(client, computation, &b); + 
EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } TEST_F(ComputeConstantTest, IndirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.Add(b.ConstantR0(1.0f), + b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + EXPECT_FALSE(IsConstant(computation, &b)); - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + auto value = ComputeConstantScalar(client, computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } // Test computation of an expression interspersed with param nodes but // the expression does not depend on the param nodes. TEST_F(ComputeConstantTest, UnrelatedParam) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); - auto constant_4 = b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto constant_4 = + b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); + auto not_constant_a = b.Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); - auto constant_9 = b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto constant_9 = + b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); + auto not_constant_b = b.Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = b.Add(constant_4, constant_9); + b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); - EXPECT_TRUE(IsConstant(constant_13, &b)); + EXPECT_TRUE(IsConstant(constant_13, &b)); - auto value = ComputeConstantScalar(constant_13, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 13.0f); + auto value = ComputeConstantScalar(client, constant_13, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 13.0f); + } } TEST_F(ComputeConstantTest, NonScalarAdd) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); - EXPECT_TRUE(IsConstant(computation, &b)); + auto computation = + b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + auto computed = 
ComputeConstantLiteral(client, computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = + LiteralUtil::CreateR1({4, 6}); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } } TEST_F(ComputeConstantTest, IntegerDivide) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); - EXPECT_TRUE(IsConstant(computation, &b)); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + auto computed = ComputeConstantLiteral(client, computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } } XLA_TEST_F(ComputeConstantTest, Layout) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - std::vector> layouts = {{0, 1}, {1, 0}}; - for (const std::vector& layout : layouts) { - auto layout_proto = LayoutUtil::MakeLayout(layout); - auto computed = - ComputeConstantLiteral(b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), - &b, &layout_proto); - ASSERT_TRUE(computed.ok()) << computed.status(); + std::vector> layouts = {{0, 1}, {1, 0}}; + for (const std::vector& layout : layouts) { + auto layout_proto = LayoutUtil::MakeLayout(layout); + auto computed = ComputeConstantLiteral( + client, + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto); + ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + std::unique_ptr expected_literal = + test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, + layout); + LiteralTestUtil::AssertEqualShapesAndLayouts( + expected_literal->shape(), computed.ValueOrDie()->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } } } @@ -207,25 +261,28 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { // Compute a trivial constant, then try to use the value in an Execute // call. This should fail because the constant resides on the CPU and the - // Execute call is executed on a different backend. - ComputationBuilder constant_b(client_, TestName()); + // Execute call is executed on a different backend. This test only makes + // sense with LocalClient, since CompileOnlyClient does not support + // execution. 
+ Client* client = ClientOrDie(platform_, ClientType::kLocal); + ComputationBuilder constant_b(client, TestName()); auto constant = constant_b.ConstantR0(42); auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie(); - auto literal = client_->Transfer(*handle).ConsumeValueOrDie(); + auto literal = client->Transfer(*handle).ConsumeValueOrDie(); LiteralTestUtil::ExpectR0Equal(42, *literal); // Build trivial computation which takes one parameter. - ComputationBuilder b(client_, TestName()); + ComputationBuilder b(client, TestName()); b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0")); auto computation = b.Build().ConsumeValueOrDie(); // Try to use value from ComputeConstant in Execute. - auto execute_status = client_->Execute(computation, {handle.get()}); + auto execute_status = client->Execute(computation, {handle.get()}); EXPECT_FALSE(execute_status.ok()); - EXPECT_MATCH( + EXPECT_THAT( execute_status.status().error_message(), - testing::ContainsRegex("argument 0 is on device Host:0 but computation " - "will be executed on device")); + ::testing::ContainsRegex("argument 0 is on device Host:0 but computation " + "will be executed on device")); } } // namespace @@ -233,6 +290,7 @@ TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9a48b19b96a..a7034930bc9 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -34,6 +36,7 @@ namespace xla { namespace { using ConcatTest = ClientLibraryTestBase; +using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { @@ -41,9 +44,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { auto concatenated = builder.ConcatInDim({}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH( - computation_status.status().ToString(), - testing::ContainsRegex("Concatenate expects at least one argument")); + EXPECT_THAT(computation_status.status().ToString(), + HasSubstr("Concatenate expects at least one argument")); } // Concatenate with one argument works. 
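The matcher changes running through these files replace XLA's bespoke `EXPECT_MATCH`/`ASSERT_MATCH` helpers with standard googlemock assertions pulled in via `tensorflow/compiler/xla/test.h`. A self-contained sketch of the new style (the test name and message below are invented for illustration):

```c++
#include <string>

#include "tensorflow/compiler/xla/test.h"

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

TEST(MatcherStyleSketch, StatusMessage) {
  std::string message = "dimension to concatenate along out of bounds: 0";
  // A fixed substring no longer needs to be phrased as a regex:
  EXPECT_THAT(message, HasSubstr("out of bounds"));
  // ContainsRegex remains available when part of the message varies:
  EXPECT_THAT(message, ContainsRegex("out of bounds: [0-9]+"));
}
```

Exact substrings are preferred where possible, since regex metacharacters in error text then no longer need escaping.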
@@ -56,6 +58,15 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a}, 0); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { @@ -65,9 +76,8 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { auto concatenated = builder.ConcatInDim({a, b}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH(computation_status.status().ToString(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: 0")); + EXPECT_THAT(computation_status.status().ToString(), + HasSubstr("dimension to concatenate along out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { @@ -404,10 +414,9 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { auto concatenated = builder.ConcatInDim({x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH( + EXPECT_THAT( computation_status.status().ToString(), - testing::ContainsRegex( - "Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected non-opaque argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { @@ -434,6 +443,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { + ComputationBuilder builder(client_, TestName()); + + Array3D arr0(9, 17, 1); + arr0.Fill(1); + + Array3D arr1(9, 17, 256); + arr1.Fill(2); + + Array3D expected(9, 17, arr0.n3() + arr1.n3()); + for (int64 i = 0; i < expected.n1(); ++i) { + for (int64 j = 0; j < expected.n2(); ++j) { + int64 kk = 0; + for (const Array3D& arr : {arr0, arr1}) { + for (int64 k = 0; k < arr.n3(); ++k, ++kk) { + expected(i, j, kk) = arr(i, j, k); + } + } + } + } + + ComputationDataHandle h0; + auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", + &builder, &h0); + ComputationDataHandle h1; + auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", + &builder, &h1); + + auto concatenated = builder.ConcatInDim({h0, h1}, 2); + + ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; @@ -494,6 +536,63 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); } +// Test that the HLO optimization to replace a concat of a broadcasted scalar +// produces the correct result in rank 1.
+XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { + auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); + auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, x_literal->shape(), "x"); + auto y = builder.Parameter(1, f32_scalar, "y"); + auto z = builder.Parameter(2, f32_scalar, "z"); + auto bcast = builder.Broadcast(y, {5}); + auto bcast2 = builder.Broadcast(z, {3}); + auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0); + builder.ConcatInDim({concat, bcast2}, /*dimension=*/0); + + ComputeAndCompareR1( + &builder, + {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f}, + {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4)); +} + +// Test that the HLO optimization to replace a concat of a broadcasted scalar +// produces the correct result in rank 3 with both high and low padding in +// different dimensions. +XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { + auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); + Array3D x3d(3, 5, 7, 3.14f); + auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, x_literal->shape(), "x"); + auto y = builder.Parameter(1, f32_scalar, "y"); + auto z = builder.Parameter(2, f32_scalar, "z"); + auto y_bcast = builder.Broadcast(y, {1, 5, 7}); + auto z_bcast = builder.Broadcast(z, {4, 1, 7}); + auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0); + builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1); + Array3D y_bcast3d(1, 5, 7, 1.5f); + Array3D z_bcast3d(4, 1, 7, 5.5f); + auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0); + auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1); + + ComputeAndCompareR3(&builder, *concat1, + {x_data.get(), y_data.get(), z_data.get()}, + ErrorSpec(1e-4)); +} + INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0}, R2BinarySpec{1, 1, 1, 1, 1}, @@ -507,6 +606,7 @@ INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 58d52ac1168..1c065de8ba7 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -177,6 +178,7 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 9f8c3a9aeb7..6d379797250 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -36,8 +37,10 @@ namespace { class ConvertTest : public ClientLibraryTestBase { public: explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) - : ClientLibraryTestBase(platform, - /*disabled_pass_names=*/{"algsimp", "inline"}) {} + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } }; TEST_F(ConvertTest, ConvertR1S32ToR1S32) { @@ -195,6 +198,7 @@ TEST_F(ConvertTest, ConvertReshape) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 9f38dc4b365..0b09416a747 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -22,15 +22,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -43,8 +43,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); - ASSERT_MATCH(dimension_numbers_status.status().error_message(), - testing::ContainsRegex("input are not unique")); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("input are not unique")); } // Tests the convolution operation with invalid weight dimension numbers. @@ -52,8 +52,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); - ASSERT_MATCH(dimension_numbers_status.status().error_message(), - testing::ContainsRegex("weight are not unique")); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("weight are not unique")); } XLA_TEST_F(ConvolutionDimensionNumbersTest, @@ -101,6 +101,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index ffbda89b948..ec19469fa66 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -48,7 +49,7 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-3); + ErrorSpec error_spec_ = ErrorSpec(1e-2); #else ErrorSpec error_spec_ = ErrorSpec(1e-4); #endif @@ -256,8 +257,7 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { error_spec_); } -// TODO(b/32873825): implement 1D convolution on GPU. 
-XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU(Convolve1D_1x2x5_1x2x2_Valid)) { +XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); @@ -284,9 +284,7 @@ XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU(Convolve1D_1x2x5_1x2x2_Valid)) { error_spec_); } -// TODO(b/32873825): implement 3D convolution on GPU. -XLA_TEST_F(ConvolutionTest, - DISABLED_ON_GPU(Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid)) { +XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ComputationBuilder builder(client_, TestName()); std::vector input_dims = {1, 4, 2, 3, 3}; std::vector filter_dims = {2, 2, 2, 3, 3}; @@ -345,6 +343,7 @@ XLA_TEST_F(ConvolutionTest, int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index b599f9b95bc..b5afc2498da 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -23,11 +23,14 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -1273,11 +1276,100 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { + ComputationBuilder builder(client_, TestName()); + + auto gradients = builder.ConstantR3FromArray3D( + Array3D(1, 1, 1, /*value=*/1)); + auto weights = + builder.ConstantR3FromArray3D(Array3D({{{1, 10, 100}}})); + auto mirrored_weights = builder.Rev(weights, {2}); + builder.ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1}, + /*padding=*/{{1, 1}}); + ComputeAndCompareR3(&builder, {{{10}}}, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { + ComputationBuilder builder(client_, TestName()); + + auto activations = + builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); + auto gradients = + builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); + auto forward_conv = builder.ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); + builder.Transpose(forward_conv, {0, 1, 2}); + + ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { + ComputationBuilder builder(client_, TestName()); + + auto gradients_flat = 
LiteralUtil::CreateR1({1}); + auto gradients_literal = + LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 1}) + .ConsumeValueOrDie(); + auto gradients = builder.ConstantLiteral(*gradients_literal); + + auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); + auto weights_literal = + LiteralUtil::Reshape(*weights_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = builder.ConstantLiteral(*weights_literal); + + auto expected_flat = LiteralUtil::CreateR1({10}); + auto expected_literal = + LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 1}).ConsumeValueOrDie(); + + auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); + builder.ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); +} + +TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { + ComputationBuilder builder(client_, TestName()); + + auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); + auto activations_literal = + LiteralUtil::Reshape(*activations_flat, {1, 1, 1, 1, 4}) + .ConsumeValueOrDie(); + auto activations = builder.ConstantLiteral(*activations_literal); + + auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); + auto gradients_literal = + LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 3}) + .ConsumeValueOrDie(); + auto gradients = builder.ConstantLiteral(*gradients_literal); + + auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); + auto expected_literal = + LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + + auto forward_conv = builder.ConvGeneralDilated( + activations, gradients, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, + ComputationBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/3)); + builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); + ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 29e29505333..4c2413d0fe4 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -18,12 +18,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -44,11 +46,10 @@ class CopyOpTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - hlo_module->AddEntryComputation(std::move(computation)); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectEqual(literal, *result); } @@ -100,11 +101,11 @@ TEST_F(CopyOpTest, CopyParameterScalar) { auto computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - hlo_module->AddEntryComputation(std::move(computation)); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {constant_device_base}); + ExecuteAndTransfer(std::move(module), {constant_device_base}); LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); } @@ -122,10 +123,9 @@ TEST_F(CopyOpTest, CopyConstantR2Twice) { auto computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - hlo_module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, error_spec_); } @@ -148,10 +148,9 @@ TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - hlo_module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. 
@@ -181,15 +180,10 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - auto config = MakeUnique(computation->ComputeProgramShape()); - *config->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), {1, 2, 0})); - hlo_module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), std::move(config), {}); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D(a, *result); } @@ -220,18 +214,10 @@ void CopyOpTest::TestCopyConstantLayoutR4( std::unique_ptr computation = builder.Build(); - auto hlo_module = MakeUnique("test_module"); - auto config = MakeUnique(computation->ComputeProgramShape()); - *config->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(ShapeUtil::MakeShapeWithLayout( - constant->shape().element_type(), - AsInt64Slice(constant->shape().dimensions()), ({ - std::vector p(permutation.rbegin(), permutation.rend()); - p; - }))); - hlo_module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), std::move(config), {}); + auto module = CreateNewModule(); + module->AddEntryComputation(std::move(computation)); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR4EqualArray4D(a, *result); } @@ -256,12 +242,29 @@ XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) { TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2}); } +using CopyOpClientTest = ClientLibraryTestBase; + +XLA_TEST_F(CopyOpClientTest, Copy0x0) { + Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); + Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); + auto empty = LiteralUtil::CreateFromShape(in_shape); + + ComputationBuilder builder(client_, TestName()); + auto param0 = builder.Parameter(0, in_shape, "input"); + auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + + auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) + .ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*empty, *actual); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index dc54c9defec..32232acf6e3 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -29,23 +30,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -extern "C" void __attribute__((visibility("default"))) -R0F32Add2(float* out, float** in) { + +extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void __attribute__((visibility("default"))) -R2F32ReduceSum(float* out, float** in) { +extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void __attribute__((visibility("default"))) -Add1ToValues(float* out, float** in) { +extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; @@ -64,7 +64,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { - auto hlo_module = MakeUnique("test_module"); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -72,15 +72,14 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); - hlo_module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { - auto hlo_module = MakeUnique("test_module"); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -94,16 +93,15 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); - hlo_module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { - auto hlo_module = MakeUnique("test_module"); + auto module = CreateNewModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ -119,10 +117,9 @@ XLA_TEST_F(CustomCallTest, HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}), {incremented, incremented_again}, 0)); - hlo_module->AddEntryComputation(b.Build()); + module->AddEntryComputation(b.Build()); - std::unique_ptr result = - ExecuteAndTransfer(std::move(hlo_module), {}); + std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); 
LiteralTestUtil::ExpectR3EqualArray3D( Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); } @@ -133,6 +130,7 @@ XLA_TEST_F(CustomCallTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 528efd2942b..074753bf6f8 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -20,16 +20,19 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { +using ::testing::HasSubstr; + class DeallocationTest : public ClientLibraryTestBase { protected: // Build and execute the given computation then verify the results can be @@ -50,7 +53,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); - // A result can be transfered an arbitrary number of times. Add an extra + // A result can be transferred an arbitrary number of times. Add an extra // transfer here so we're not just testing that a second call to Transfer // fails. 
ASSERT_IS_OK(client_->Transfer(*global_data).status()); @@ -59,8 +62,8 @@ TEST_F(DeallocationTest, DeallocateScalar) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } TEST_F(DeallocationTest, DeallocateVector) { @@ -72,8 +75,8 @@ TEST_F(DeallocationTest, DeallocateVector) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } TEST_F(DeallocationTest, DeallocateEmptyVector) { @@ -85,8 +88,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateTuple) { @@ -99,8 +102,8 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { @@ -114,8 +117,8 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { @@ -130,8 +133,8 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } } // namespace @@ -139,6 +142,7 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 57a7c61b141..fcddffc1e13 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -21,9 +21,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -34,6 +36,9 @@ limitations under the License. namespace xla { namespace { +using ::testing::ContainsRegex; +using ::testing::HasSubstr; + class DeconstructTupleTest : public ClientLibraryTestBase { protected: // Build and execute the given computation then verify the results can be @@ -61,11 +66,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::vector copy(4); - ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + std::unique_ptr literal; + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -82,19 +87,20 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::vector copy(4); - ASSERT_IS_OK(client_->TransferInProcess(*handles1[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles1[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + std::unique_ptr literal; + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[0])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles1[1])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + handles1[0].reset(); handles1[1].reset(); - ASSERT_IS_OK(client_->TransferInProcess(*handles2[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles2[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[0])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles2[1])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -112,15 +118,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. 
auto handles = result_status.ConsumeValueOrDie(); - std::vector copy(4); - ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[3], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + std::unique_ptr literal; + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[3])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -138,19 +144,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have been deallocated because of reference counting. global_data.reset(); - std::vector copy(4); - ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); - ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + std::unique_ptr literal; + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[0])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[1])); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); - ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handles[2])); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -160,8 +166,8 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH(result_status.status().error_message(), - testing::ContainsRegex("global data handle .* is not a tuple")); + EXPECT_THAT(result_status.status().error_message(), + ContainsRegex("global data handle .* is not a tuple")); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { @@ -189,9 +195,8 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("deconstructing nested tuples not yet supported")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("deconstructing nested tuples not yet supported")); } } // namespace @@
@@ -199,6 +204,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) {
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
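The rewrite running through deconstruct_tuple_test.cc above is mechanical: every raw `TransferInProcess` copy into a caller-owned buffer becomes a `Transfer` call that returns a self-describing `Literal`. Side by side, with `handle` standing in for any deconstructed element handle (a sketch assembled from the hunks above, not an additional change in this patch):

```c++
// Before: copy device data into a preallocated host vector, then match values.
std::vector<float> copy(4);
ASSERT_IS_OK(client_->TransferInProcess(*handle, &copy[0]));
EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));

// After: fetch a Literal (shape plus data) and compare via LiteralTestUtil.
std::unique_ptr<Literal> literal;
TF_ASSIGN_OR_ASSERT_OK(literal, client_->Transfer(*handle));
LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
```

The `Literal` form also spares the caller from knowing the element count up front.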
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 197a8f86cb0..754eec1b1ed 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
@@ -54,6 +55,8 @@ class DotOperationTest : public ClientLibraryTestBase {
   template <typename Element>
   void TestNonsquareMatrixDot(bool lhs_row_major = false,
                               bool rhs_row_major = false);
+  void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false,
+                     bool rhs_row_major = false);
 };
 
 XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
@@ -65,6 +68,15 @@ XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
   ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
 }
 
+XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
+  auto rhs = builder.ConstantR1<float>({3.0, 4.0});
+  auto result = builder.Dot(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
+}
+
 template <typename Element>
 void DotOperationTest::TestOneElementVectorDot() {
   ComputationBuilder builder(client_, TestName());
@@ -170,6 +182,84 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
       &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
 }
 
+void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major,
+                                     bool rhs_row_major) {
+  std::unique_ptr<Array2D<float>> lhs_data =
+      MakeLinspaceArray2D(0.0, 1.0, M, K);
+  std::unique_ptr<Literal> lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+      *lhs_data,
+      LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)));
+  auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie();
+
+  std::unique_ptr<Array2D<float>> rhs_data =
+      MakeLinspaceArray2D(0.0, 1.0, K, N);
+  std::unique_ptr<Literal> rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+      *rhs_data,
+      LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)));
+  auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie();
+
+  ComputationBuilder builder(client_, TestName());
+  auto prim_type = primitive_util::NativeToPrimitiveType<float>();
+  auto result = builder.Dot(
+      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"),
+      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs"));
+
+  std::unique_ptr<Array2D<float>> expected =
+      ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data);
+
+  ComputeAndCompareR2<float>(&builder, *expected,
+                             {lhs_handle.get(), rhs_handle.get()},
+                             ErrorSpec(0.3, 3e-3));
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) {
+  TestMatrixDot(12, 117, 7, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) {
+  TestMatrixDot(12, 117, 7, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) {
+  TestMatrixDot(12, 117, 7, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) {
+  TestMatrixDot(12, 117, 7, false, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) {
+  TestMatrixDot(270, 270, 520, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) {
+  TestMatrixDot(270, 270, 520, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) {
+  TestMatrixDot(270, 270, 520, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) {
+  TestMatrixDot(270, 270, 520, false, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) {
+  TestMatrixDot(269, 3, 520, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) {
+  TestMatrixDot(260, 3, 520, true, false);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) {
+  TestMatrixDot(260, 3, 520, false, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
+  TestMatrixDot(260, 3, 520, false, false);
+}
+
 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
   constexpr bool kLhsRowMajor = false;
   constexpr bool kRhsRowMajor = false;
@@ -277,9 +367,9 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
   std::vector<ComputationDataHandle> out_slices;
   for (int i = 0; i < 4; ++i) {
     // Slice off individual matrices and reshape to 2D tensors.
-    auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
+    auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
     x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
-    auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
+    auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
     y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
 
     auto out = builder.Dot(x_slice, y_slice);
@@ -371,6 +461,7 @@ TEST_F(DotOperationTest, TransposeFolding) {
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
   xla::legacy_flags::AppendLayoutUtilFlags(&flag_list);
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
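About the `MinorToMajorTF`/`FT`/`TT`/`FF` suffixes on the new tests: the two booleans passed to `TestMatrixDot` choose row-major (`true`) or column-major (`false`) layouts for the LHS and RHS operands. `MinorToMajorForIsRowMajor` is defined elsewhere in the test infrastructure and is not shown in this diff, so the body below is an assumption sketched for orientation only:

```c++
// For a 2D array, row-major means dimension 1 varies fastest (is most minor),
// giving minor_to_major = {1, 0}; column-major is the reverse, {0, 1}.
std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
  return row_major ? std::vector<int64>{1, 0} : std::vector<int64>{0, 1};
}
```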
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -57,6 +59,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4}, {6.0, 7.0, 0.0, 1.0}); + // Zero element slice. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {0}, {}); } template @@ -74,6 +78,12 @@ class DynamicSliceTest : public ClientLibraryTestBase { RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, {1, 1}, {3, 3}, {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}}); + // Zero element slice: 2x0. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {0, 0}, {2, 0}, {{}, {}}); + // Zero element slice: 0x2. + RunR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {0, 0}, {0, 2}, Array2D(0, 2)); } template @@ -108,7 +118,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR1(const std::vector& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const std::vector& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -126,7 +136,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR2(const Array2D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array2D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -144,7 +154,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR3(const Array3D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array3D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -199,6 +209,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {8.0, 9.0, 10.0}, {6}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0}); + // Zero-sized update. + RunR1({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, + {}, {2}, + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); // clang-format on } @@ -225,6 +239,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, {{10.0f, 11.0f}}, {2, 2}, {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}}); + // Zero-sized update. 
+ RunR2( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}, + {{}}, {2, 1}, + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); // clang-format on } @@ -474,19 +493,23 @@ void BM_DynamicSlice(int num_iters) { executors[device_ordinal], *start_indices_literal, buffer->mutable_buffer({}))); + std::unique_ptr executable = + client->Compile(computation, {&buffer->shape()}, ExecutableBuildOptions()) + .ConsumeValueOrDie(); + // Run some warm-up executions. - LocalExecuteOptions options; + ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = client->ExecuteLocally(computation, {buffer.get()}, options); + auto result = executable->Run({buffer.get()}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = client->ExecuteLocally(computation, {buffer.get()}, options); + auto result = executable->Run({buffer.get()}, options); ASSERT_TRUE(result.ok()); } } @@ -497,6 +520,7 @@ BENCHMARK(BM_DynamicSlice); int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 8e300630858..80267e5459d 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -112,6 +113,7 @@ TEST_F(FloorCeilTest, R0Ceil) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 2835038c90c..ee4e92505d9 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" @@ -45,6 +46,7 @@ TEST_F(FmaxSimpleTest, FmaxTenValues) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 7bddbfa894c..fa36381267e 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" using tensorflow::gtl::ArraySlice; @@ -74,7 +74,7 @@ class FusionTest : public HloTestBase { } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -176,7 +176,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -204,7 +204,7 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2})); + ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1})); // CreateFusionInstruction needs the `instructions_to_fuse` argument in // reverse topological order, so the first element in `instructions_to_fuse` // must be the root. @@ -224,7 +224,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -247,7 +247,7 @@ XLA_TEST_F(FusionTest, Parameter) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -271,7 +271,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -285,7 +285,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -300,7 +300,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -315,7 +315,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -343,7 +343,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -357,7 +357,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = MakeUnique(TestName()); + auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto 
   auto reshape1 = builder.AddInstruction(
@@ -372,7 +372,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
 
 XLA_TEST_F(FusionTest, Transpose_2by3) {
   auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
   auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -387,7 +387,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
 
 XLA_TEST_F(FusionTest, Transpose_3by3) {
   auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
   auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -402,7 +402,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
 
 XLA_TEST_F(FusionTest, Reverse) {
   auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
   auto const0 = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
   auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
@@ -427,7 +427,7 @@ std::unique_ptr<HloComputation> MakeReduceTestComputation() {
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
 
   auto builder = HloComputation::Builder(TestName());
   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -446,7 +446,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
 
   auto builder = HloComputation::Builder(TestName());
   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
@@ -468,7 +468,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
   auto builder = HloComputation::Builder(TestName());
-  auto hlo_module = MakeUnique<HloModule>(TestName());
+  auto hlo_module = CreateNewModule();
   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
   auto const1 = builder.AddInstruction(
@@ -574,6 +574,7 @@ XLA_TEST_F(FusionTest, Clamp2D) {
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
   if (!parse_result) {
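Every `MakeUnique<HloModule>(TestName())` call site replaced in fusion_test.cc above now goes through `CreateNewModule()`, which is added to `HloTestBase` later in this patch; the point is that each test module picks up the debug options parsed from command-line flags instead of a blank config. The resulting test skeleton, as a sketch (builder contents elided):

```c++
auto builder = HloComputation::Builder(TestName());
// ... add HLO instructions to the builder ...
auto hlo_module = CreateNewModule();  // named after the test, flags applied
hlo_module->AddEntryComputation(builder.Build());
```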
diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
new file mode 100644
index 00000000000..f54fa2256e2
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+
+namespace xla {
+namespace {
+
+class HloMetadataTest : public LocalClientTestBase {
+ protected:
+  HloMetadataTest() {
+    metadata_.set_op_type("add");
+    metadata_.set_op_name("my_sum_op");
+  }
+
+  void BuildAddComputation(ComputationBuilder* builder) {
+    auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+    auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+    builder->Add(x, y);
+  }
+
+  OpMetadata metadata_;
+};
+
+TEST_F(HloMetadataTest, MetadataPropagation) {
+  ComputationBuilder builder(local_client_, "add");
+  builder.SetOpMetadata(metadata_);
+  BuildAddComputation(&builder);
+  builder.ClearOpMetadata();
+
+  Shape argument_layout = ShapeUtil::MakeShape(F32, {});
+  TF_ASSIGN_OR_ASSERT_OK(
+      std::unique_ptr<LocalExecutable> executable,
+      local_client_->Compile(builder.Build().ValueOrDie(),
+                             {&argument_layout, &argument_layout},
+                             ExecutableBuildOptions()));
+
+  auto instruction = executable->executable()
+                         ->module()
+                         .entry_computation()
+                         ->root_instruction();
+  EXPECT_EQ("add", instruction->metadata().op_type());
+  EXPECT_EQ("my_sum_op", instruction->metadata().op_name());
+}
+
+TEST_F(HloMetadataTest, MetadataClearing) {
+  ComputationBuilder builder(local_client_, "add");
+  builder.SetOpMetadata(metadata_);
+  // Some other pretend computation here.
+  builder.ClearOpMetadata();
+  BuildAddComputation(&builder);
+
+  Shape argument_layout = ShapeUtil::MakeShape(F32, {});
+  auto executable_status = local_client_->Compile(
+      builder.Build().ValueOrDie(), {&argument_layout, &argument_layout},
+      ExecutableBuildOptions());
+  ASSERT_IS_OK(executable_status);
+
+  std::unique_ptr<LocalExecutable> executable =
+      executable_status.ConsumeValueOrDie();
+
+  auto instruction = executable->executable()
+                         ->module()
+                         .entry_computation()
+                         ->root_instruction();
+  // We expect these to be empty (no metadata set).
+  EXPECT_EQ("", instruction->metadata().op_type());
+  EXPECT_EQ("", instruction->metadata().op_name());
+}
+
+}  // namespace
+}  // namespace xla
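The new hlo_metadata_test.cc pins down the scoping rule for op metadata on `ComputationBuilder`: metadata attaches to every op built while it is set and stops applying once cleared. In outline (a sketch reusing the test's names; `Mul` is just an illustrative second op):

```c++
OpMetadata metadata;
metadata.set_op_type("add");
metadata.set_op_name("my_sum_op");

builder.SetOpMetadata(metadata);
auto sum = builder.Add(x, y);      // carries op_type "add", op_name "my_sum_op"
builder.ClearOpMetadata();
auto product = builder.Mul(x, y);  // built after the clear: empty metadata
```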
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace se = ::perftools::gputools; @@ -55,6 +56,8 @@ struct HloTestBase::EigenThreadPoolWrapper { HloTestBase::HloTestBase() : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { + // TODO(b/62411181): get rid of this flag entirely when the usual debug flags + // are piped to all HLO tests. test_hlo_dumper_ = [](const HloModule& module, const string& label) { legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags(); if (flags->xla_hlo_test_generate_hlo_graph) { @@ -74,30 +77,21 @@ HloTestBase::~HloTestBase() { } } +std::unique_ptr HloTestBase::CreateNewModule() { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} + StatusOr HloTestBase::Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - auto module_config = MakeUnique( - module->entry_computation()->ComputeProgramShape()); - return Execute(std::move(module), std::move(module_config), arguments, - result_shape); -} - -StatusOr HloTestBase::Execute( - std::unique_ptr hlo_module, - std::unique_ptr module_config, - tensorflow::gtl::ArraySlice arguments, - Shape* result_shape) { - VLOG(3) << "module_config layout " - << LayoutUtil::HumanString(module_config->entry_computation_layout() - .result_layout() - .layout()); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend_->compiler()->Compile(std::move(hlo_module), - std::move(module_config), test_hlo_dumper_, + backend_->compiler()->Compile(std::move(module), test_hlo_dumper_, backend_->default_stream_executor())); se::Stream stream(backend_->default_stream_executor()); @@ -111,9 +105,13 @@ StatusOr HloTestBase::Execute( backend_->eigen_intra_op_thread_pool_device()); HloExecutionProfile hlo_execution_profile; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result, - executable->ExecuteOnStream(&run_options, arguments, - &hlo_execution_profile)); + ServiceExecutableRunOptions service_run_options( + run_options, backend_->StreamBorrower(), + backend_->inter_op_thread_pool()); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase result, + executable->ExecuteOnStream(&service_run_options, arguments, + &hlo_execution_profile)); TF_RET_CHECK(stream.BlockHostUntilDone()); allocations_.push_back(result); @@ -133,6 +131,7 @@ StatusOr HloTestBase::Execute( std::set added_opaques; for (auto element_buffer : element_buffers) { if (added_opaques.count(element_buffer.opaque()) == 0) { + CHECK(element_buffer.opaque() != nullptr); added_opaques.insert(element_buffer.opaque()); allocations_.push_back(element_buffer); } @@ -175,20 +174,26 @@ std::unique_ptr HloTestBase::ExecuteAndTransfer( return TransferFromDevice(result_shape, device_base); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - 
std::unique_ptr module_config, - tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - se::DeviceMemoryBase device_base = - Execute(std::move(module), std::move(module_config), arguments, - &result_shape) - .ValueOrDie(); - return TransferFromDevice(result_shape, device_base); -} - -string HloTestBase::TestName() const { +/* static */ +string HloTestBase::TestName() { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } +int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + ::testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 91fc9b87cd5..98bc35ae528 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -45,6 +44,12 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override; + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. It's recommended to use this method to + // create all HloModules for tests. + std::unique_ptr CreateNewModule(); + // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, @@ -52,20 +57,11 @@ class HloTestBase : public ::testing::Test { arguments, Shape* result_shape); - // Variation of Execute which takes a custom module_config instead of creating - // a default one. - StatusOr Execute( - std::unique_ptr module, - std::unique_ptr module_config, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); - // Transfers the given literal to the device and returns the data handle. perftools::gputools::DeviceMemoryBase TransferToDevice( const Literal& literal); - // Transfers the array refered to by the given handle from the device and + // Transfers the array referred to by the given handle from the device and // returns as a Literal. std::unique_ptr TransferFromDevice( const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); @@ -76,15 +72,35 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); - // Variation of ExecuteAndTransfer which takes a custom module_config instead - // of creating a default one. - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, - std::unique_ptr module_config, - tensorflow::gtl::ArraySlice - arguments); + // Convenience method to force the layout of a given parameter in a module. + // The layout of parameter number 'param_no' in the 'module' is set to + // 'layout'. 
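`ParseDebugOptionsFlagsAndRunTests` exists so individual test binaries can stop hand-rolling the flag-parsing boilerplate repeated in the `main()` hunks earlier in this patch. Per its comment, a test `main()` reduces to:

```c++
int main(int argc, char** argv) {
  return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
}
```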
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 91fc9b87cd5..98bc35ae528 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -45,6 +44,12 @@ class HloTestBase : public ::testing::Test {
 
   ~HloTestBase() override;
 
+  // Creates a new HLO module for a test. The module created will have
+  // TestName() for its name; it will also automatically populate its debug
+  // options from command-line flags. It's recommended to use this method to
+  // create all HloModules for tests.
+  std::unique_ptr<HloModule> CreateNewModule();
+
   // Executes the given module and returns a global data handle.
   StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
       std::unique_ptr<HloModule> module,
@@ -52,20 +57,11 @@ class HloTestBase : public ::testing::Test {
           arguments,
       Shape* result_shape);
 
-  // Variation of Execute which takes a custom module_config instead of creating
-  // a default one.
-  StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
-      std::unique_ptr<HloModule> module,
-      std::unique_ptr<HloModuleConfig> module_config,
-      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
-          arguments,
-      Shape* result_shape);
-
   // Transfers the given literal to the device and returns the data handle.
   perftools::gputools::DeviceMemoryBase TransferToDevice(
       const Literal& literal);
 
-  // Transfers the array refered to by the given handle from the device and
+  // Transfers the array referred to by the given handle from the device and
   // returns as a Literal.
   std::unique_ptr<Literal> TransferFromDevice(
       const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
@@ -76,15 +72,35 @@ class HloTestBase : public ::testing::Test {
       tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
          arguments);
 
-  // Variation of ExecuteAndTransfer which takes a custom module_config instead
-  // of creating a default one.
-  std::unique_ptr<Literal> ExecuteAndTransfer(
-      std::unique_ptr<HloModule> module,
-      std::unique_ptr<HloModuleConfig> module_config,
-      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
-          arguments);
+  // Convenience method to force the layout of a given parameter in a module.
+  // The layout of parameter number 'param_no' in the 'module' is set to
+  // 'layout'.
+  void ForceParameterLayout(HloModule* module, int64 param_no,
+                            const Layout& layout) {
+    ASSERT_LT(param_no,
+              module->mutable_entry_computation_layout()->parameter_count());
+    module->mutable_entry_computation_layout()
+        ->mutable_parameter_layout(param_no)
+        ->ResetLayout(layout);
+  }
 
-  string TestName() const;
+  // Convenience method to force the layout of the computation result in a
+  // module. The result layout of 'module' is set to 'layout'.
+  void ForceResultLayout(HloModule* module, const Layout& layout) {
+    module->mutable_entry_computation_layout()
+        ->mutable_result_layout()
+        ->ResetLayout(layout);
+  }
+
+  // Convenience method to clear the layout of the computation result in
+  // 'module'.
+  void ForceClearResultLayout(HloModule* module) {
+    module->mutable_entry_computation_layout()
+        ->mutable_result_layout()
+        ->Clear();
+  }
+
+  static string TestName();
 
   std::unique_ptr<Backend> backend_;
 
@@ -99,6 +115,11 @@ class HloTestBase : public ::testing::Test {
   std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
 };
 
+// Convenience function that parses XLA debug options flags from argc/argv,
+// calls InitGoogleTest and then calls and returns RUN_ALL_TESTS. Intended to be
+// invoked from a test main() function.
+int ParseDebugOptionsFlagsAndRunTests(int argc, char** argv);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/inprocess_service_test.cc b/tensorflow/compiler/xla/tests/inprocess_service_test.cc
deleted file mode 100644
index ea0be07872f..00000000000
--- a/tensorflow/compiler/xla/tests/inprocess_service_test.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
-#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-namespace {
-
-// Tests which exercise the "InProcess" methods of xla::Client. The
-// "InProcess" methods require that the client and server share the same
-// process.
-class InProcessServiceTest : public ClientLibraryTestBase {
- protected:
-  std::unique_ptr<GlobalData> ExecuteR2F32Constant(
-      std::initializer_list<std::initializer_list<float>> values,
-      tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-    ComputationBuilder builder(client_, TestName());
-    builder.ConstantR2<float>(values);
-    auto computation = builder.Build().ConsumeValueOrDie();
-    CHECK_EQ(2, minor_to_major.size());
-
-    ExecutionOptions execution_options;
-    *execution_options.mutable_shape_with_output_layout() =
-        ShapeUtil::MakeShapeWithLayout(
-            F32,
-            /*dimensions=*/{static_cast<int64>(values.size()),
-                            static_cast<int64>(values.begin()->size())},
-            minor_to_major);
-    return client_->Execute(computation, {}, &execution_options)
-        .ConsumeValueOrDie();
-  }
-
-  ErrorSpec error_spec_{0.0001};
-};
-
-XLA_TEST_F(InProcessServiceTest, TransferFromServer) {
-  ComputationBuilder builder(client_, TestName());
-  builder.ConstantR1<int32>({1, 42, 5});
-  auto computation = builder.Build().ConsumeValueOrDie();
-
-  auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
-
-  std::vector<int32> result(3, 0);
-  ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
-  EXPECT_MATCH(result, testing::VectorMatcher<int32>({1, 42, 5}));
-}
-
-XLA_TEST_F(InProcessServiceTest, TransferToServer) {
-  std::vector<float> input{1.0f, 2.0f, -42.0f};
-  Shape shape = ShapeUtil::MakeShape(F32, {3});
-  auto data_handle = client_->TransferToServerInProcess(shape, input.data())
-                         .ConsumeValueOrDie();
-
-  ComputationBuilder builder(client_, TestName());
-  auto param = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "param");
-  builder.Add(param, param);
-
-  ComputeAndCompareR1<float>(&builder, {2.0f, 4.0f, -84.0f},
-                             {data_handle.get()}, error_spec_);
-}
-
-// TODO(b/28506710): This test case seems not to test inprocess
-// methods.
-TEST_F(InProcessServiceTest, GetShape) {
-  ComputationBuilder builder(client_, TestName());
-  builder.ConstantR1<int32>({1, 42, 5});
-  auto computation = builder.Build().ConsumeValueOrDie();
-
-  auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
-
-  Shape shape = client_->GetShape(*handle).ConsumeValueOrDie();
-  ASSERT_EQ(S32, shape.element_type());
-  ASSERT_EQ(1, ShapeUtil::Rank(shape));
-  ASSERT_EQ(3, shape.dimensions(0));
-}
-
-XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayRowMajor) {
-  std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
-  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
-  shape.clear_layout();
-  *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
-  auto handle = client_->TransferToServerInProcess(shape, input.data())
-                    .ConsumeValueOrDie();
-
-  Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
-  ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
-}
-
-XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayColMajor) {
-  std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
-  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
-  shape.clear_layout();
-  *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
-  auto handle = client_->TransferToServerInProcess(shape, input.data())
-                    .ConsumeValueOrDie();
-
-  Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
-  ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
-}
-
-TEST_F(InProcessServiceTest, TransferToServerNoLayout) {
-  std::vector<float> input{1.0f, 2.0f, -42.0f};
-  Shape shape = ShapeUtil::MakeShape(F32, {3});
-  shape.clear_layout();
-  auto transfer_status =
-      client_->TransferToServerInProcess(shape, input.data());
-  ASSERT_EQ(transfer_status.status().code(),
-            tensorflow::error::INVALID_ARGUMENT);
-}
-
-XLA_TEST_F(InProcessServiceTest, ExecuteRowMajor) {
-  auto handle =
-      ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
-
-  std::vector<float> result(4, 0.0);
-  Shape shape;
-  ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
-
-  EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
-}
-
-XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) {
-  auto handle =
-      ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1});
-
-  std::vector<float> result(4, 0);
-  Shape shape;
-  ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
-
-  EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 3.0, 2.0, 4.0}));
-}
-
-XLA_TEST_F(InProcessServiceTest, ExecuteAndReuseDifferentLayouts) {
-  // Create arrays on the server which have different layouts. Verify the
-  // computation still produces the correct results.
-  auto handle_rowmaj =
-      ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
-
-  auto handle_colmaj = ExecuteR2F32Constant({{10.0, 20.0}, {30.0, 40.0}},
-                                            /*minor_to_major=*/{0, 1});
-
-  ComputationBuilder builder(client_, TestName());
-  auto param0 =
-      builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
-  auto param1 =
-      builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "param1");
-  builder.Add(param0, param1);
-
-  Array2D<float> expected({{11.0, 22.0}, {33.0, 44.0}});
-  ComputeAndCompareR2<float>(&builder, expected,
-                             {handle_rowmaj.get(), handle_colmaj.get()},
-                             error_spec_);
-}
-
-}  // namespace
-}  // namespace xla
-
-int main(int argc, char** argv) {
-  std::vector<tensorflow::Flag> flag_list;
-  xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
-  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
-  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
-  if (!parse_result) {
-    LOG(ERROR) << "\n" << usage;
-    return 2;
-  }
-  testing::InitGoogleTest(&argc, argv);
-  if (argc > 1) {
-    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
-    return 2;
-  }
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index f7bbc0f38bb..eb979ad189d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/io/path.h"
@@ -76,11 +76,11 @@ string Hostname() {
 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
 // -- on miscompare, a nice error message is given in the AssertionFailure.
 template <typename FloatT, typename UnsignedT>
-testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
   auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
   auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
   if (ulhs != urhs) {
-    return testing::AssertionFailure() << tensorflow::strings::Printf(
+    return ::testing::AssertionFailure() << tensorflow::strings::Printf(
               "floating values are not bitwise-equal; and equality testing "
               "was requested: %s=%g=%a vs %s=%g=%a",
               tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
@@ -90,33 +90,33 @@ testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
                   .c_str(),
               rhs, rhs);
   }
-  return testing::AssertionSuccess();
+  return ::testing::AssertionSuccess();
 }
 
 // Templated comparator that specializes for float equality comparison with the
 // bitwise helper above (this is the un-specialized fallback, to just use the
 // default gunit implementation).
 template <typename NativeT>
-testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
+::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
   if (lhs == rhs) {
-    return testing::AssertionSuccess();
+    return ::testing::AssertionSuccess();
   }
   ::testing::Message msg;
   msg << "Expected equality of these values:";
   msg << "\n  " << lhs;
   msg << "\n  " << rhs;
-  return testing::AssertionFailure() << msg;
+  return ::testing::AssertionFailure() << msg;
 }
 
 // Specializations for floating types that do bitwise comparisons when equality
 // comparison is requested.
 template <>
-testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
+::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
   return CompareFloatsBitwiseEqual(lhs, rhs);
 }
 
 template <>
-testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
+::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
   return CompareFloatsBitwiseEqual(lhs, rhs);
 }
 
@@ -130,7 +130,7 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
   if (dimension == expected.shape().dimensions_size()) {
     NativeT expected_value = LiteralUtil::Get<NativeT>(expected, multi_index);
     NativeT actual_value = LiteralUtil::Get<NativeT>(actual, multi_index);
-    testing::AssertionResult result =
+    ::testing::AssertionResult result =
         CompareEqual(expected_value, actual_value);
     return result;  // Defines implicit coersion to bool.
   }
@@ -159,7 +159,7 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
   EXPECT_FALSE(Equal(expected, actual));
 }
 
-/* static */ testing::AssertionResult LiteralTestUtil::Equal(
+/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
     const Literal& expected, const Literal& actual) {
   VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
   VLOG(1) << "actual:   " << LiteralUtil::ToString(actual);
@@ -207,9 +207,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
         << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
         << PrimitiveType_Name(expected.shape().element_type());
   }
-  testing::AssertionResult result = testing::AssertionSuccess();
+  ::testing::AssertionResult result = ::testing::AssertionSuccess();
   if (!match) {
-    result = testing::AssertionFailure()
+    result = ::testing::AssertionFailure()
              << "expected: " << LiteralUtil::ToString(expected)
              << "\nactual:   " << LiteralUtil::ToString(actual);
     VLOG(1) << result.message();
@@ -262,7 +262,7 @@ class NearComparator {
     max_abs_err_ = 0.0;
     *miscompares_.mutable_shape() =
         ShapeUtil::ChangeElementType(actual.shape(), PRED);
-    miscompares_.mutable_preds()->Resize(
+    miscompares_.mutable_preds()->resize(
         ShapeUtil::ElementsIn(miscompares_.shape()), false);
     multi_index_.resize(expected.shape().dimensions_size(), 0);
 
@@ -314,7 +314,7 @@ class NearComparator {
 
  private:
   // EXPECTs that the two given scalar values are within the error bound. Keeps
-  // track of how many mismatches have occured to keep the size of the output
+  // track of how many mismatches have occurred to keep the size of the output
   // manageable.
   template <typename NativeT>
   bool ExpectValuesNear(NativeT expected, NativeT actual) {
@@ -389,7 +389,7 @@ class NearComparator {
         tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
                                     now_usec, name.c_str()));
     TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
-                                             filename, literal));
+                                             filename, literal.ToProto()));
     LOG(ERROR) << "wrote to " << name << " file: " << filename;
   }
 
@@ -421,12 +421,12 @@ class NearComparator {
 
 }  // namespace
 
-/* static */ testing::AssertionResult LiteralTestUtil::Near(
+/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
     const Literal& expected, const Literal& actual, const ErrorSpec& error) {
   NearComparator comparator(error);
   return comparator.ExpectNear(expected, actual)
-             ? testing::AssertionSuccess()
-             : testing::AssertionFailure() << "values were not near";
+             ? ::testing::AssertionSuccess()
+             : ::testing::AssertionFailure() << "values were not near";
 }
 
 /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
@@ -435,14 +435,14 @@ class NearComparator {
   EXPECT_TRUE(Near(expected, actual, error));
 }
 
-/* static */ testing::AssertionResult LiteralTestUtil::NearTuple(
+/* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple(
    const Literal& expected, const Literal& actual, const ErrorSpec& error) {
   VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
   VLOG(1) << "actual:   " << LiteralUtil::ToString(actual);
 
   if (!ShapeUtil::IsTuple(expected.shape()) ||
       !ShapeUtil::IsTuple(actual.shape())) {
-    return testing::AssertionFailure()
+    return ::testing::AssertionFailure()
            << "tuples expected expected shape = "
           << expected.shape().ShortDebugString()
            << " actual shape = " << actual.shape().ShortDebugString();
@@ -469,7 +469,7 @@ class NearComparator {
     }
   }
 
-  return testing::AssertionSuccess();
+  return ::testing::AssertionSuccess();
 }
 
 /* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected,
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 85656a53e44..a8b07a2c5d1 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -18,15 +18,18 @@ limitations under the License.
 
 #include <initializer_list>
 #include <memory>
+#include <random>
 #include <string>
 
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/array4d.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/test.h"
@@ -57,7 +60,7 @@ class LiteralTestUtil {
   // Asserts that the expected and actual literals are (bitwise) equal for all
   // elements in the literal. Also, asserts that the rank, dimensions sizes, and
   // primitive type are equal.
-  static testing::AssertionResult Equal(
+  static ::testing::AssertionResult Equal(
      const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
 
   // Expects that expected and actual are Equal.
@@ -101,7 +104,7 @@ class LiteralTestUtil {
   // Asserts that the expected and actual literals are within the given error
   // bound for all elements. Also, asserts that the rank, dimensions sizes, and
   // bounds are equivalent. Only supported for floating point values.
-  static testing::AssertionResult Near(
+  static ::testing::AssertionResult Near(
       const Literal& expected, const Literal& actual,
       const ErrorSpec& error) TF_MUST_USE_RESULT;
 
@@ -147,7 +150,7 @@ class LiteralTestUtil {
   // tuples are within the given error bound. Tuples are matched recursively.
   // If the elements of the tuple are not floating-point types, the error spec
   // is ignored and exact equality is checked.
-  static testing::AssertionResult NearTuple(
+  static ::testing::AssertionResult NearTuple(
      const Literal& expected, const Literal& actual,
      const ErrorSpec& error) TF_MUST_USE_RESULT;
 
@@ -170,6 +173,36 @@ class LiteralTestUtil {
       tensorflow::gtl::ArraySlice<int64> minor_to_major,
       const Literal& literal);
 
+  // Creates a literal with the supplied shape, and uses the provided value
+  // generator to populate the literal's values.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape,
+      const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+
+  // Creates a literal with the supplied shape, and initializes the literal
+  // values using a normal distribution with given mean and stddev standard
+  // deviation, and using the engine as entropy generator.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type, typename E,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape, E* engine, T mean, T stddev);
+
+  // Creates a literal with the supplied shape, and initializes the literal
+  // values using a normal distribution with given mean and stddev standard
+  // deviation.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape, T mean, T stddev);
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
 };
 
@@ -269,6 +302,40 @@ template <typename NativeT>
   ExpectNear(*LiteralUtil::CreateR4FromArray4D<NativeT>(expected), actual,
              error);
 }
 
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(
+    const Shape& shape,
+    const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+  TF_RET_CHECK(shape.element_type() == type);
+  std::unique_ptr<Literal> literal = LiteralUtil::CreateFromShape(shape);
+  TF_RETURN_IF_ERROR(LiteralUtil::Populate<NativeT>(
+      literal.get(), [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+        return generator(indexes);
+      }));
+  return std::move(literal);
+}
+
+template <PrimitiveType type, typename E, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
+                                     T stddev) {
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+  std::normal_distribution<NativeT> generator(mean, stddev);
+  return CreateRandomLiteral<type>(
+      shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
+        return generator(*engine);
+      });
+}
+
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+  std::minstd_rand0 engine;
+  return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+}
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index fdec11c0e98..a94f45f73b7 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -83,9 +83,10 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
   LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
   EXPECT_EQ(3, results.size());
   for (const string& result : results) {
-    Literal literal;
+    LiteralProto literal_proto;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
-                                            &literal));
+                                            &literal_proto));
+    Literal literal(literal_proto);
     if (result.find("expected") != string::npos) {
       EXPECT_EQ("2", LiteralUtil::ToString(literal));
     } else if (result.find("actual") != string::npos) {
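Typical usage of the new `CreateRandomLiteral` overloads in a test would look like the sketch below; the shape and distribution parameters are illustrative:

```c++
// Draw a 2x3 F32 literal from N(0, 1) using the default-engine overload.
TF_ASSIGN_OR_ASSERT_OK(
    std::unique_ptr<Literal> literal,
    LiteralTestUtil::CreateRandomLiteral<F32>(
        ShapeUtil::MakeShape(F32, {2, 3}), /*mean=*/0.0f, /*stddev=*/1.0f));
```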
a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -44,9 +44,8 @@ TEST_F(LocalClientAotTest, Constant) { OpaqueData opaque_data{100, 20, 3}; void* parameters[] = {&opaque_data}; float out = 0; - float tmp1 = 0; - float tmp2 = 0; - void* temporary_buffers[] = {&out, &tmp1, &tmp2, nullptr}; + char tmp[20] = {0}; + void* temporary_buffers[] = {&out, nullptr, &tmp}; SumAndDouble(&out, &run_options, parameters, temporary_buffers); EXPECT_EQ(out, 246.0f); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index eed51bd6ad4..52816dc72cc 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - auto client = xla::ClientLibrary::LocalClientOrDie(); + auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); xla::ComputationBuilder builder(client, "aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); @@ -74,7 +74,7 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::LocalClient::AheadOfTimeComputationInstance instance{ + xla::CompileOnlyClient::AotComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( @@ -89,11 +89,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 0); - CHECK_EQ(result->buffer_sizes().size(), 4); + CHECK_EQ(result->buffer_sizes().size(), 3); CHECK_EQ(result->buffer_sizes()[0], sizeof(float)); // result buffer - CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // temp buffer - CHECK_EQ(result->buffer_sizes()[2], sizeof(float)); // temp buffer - CHECK_EQ(result->buffer_sizes()[3], -1); // param buffer + CHECK_EQ(result->buffer_sizes()[1], -1); // param buffer + CHECK_EQ(result->buffer_sizes()[2], 20); // temp buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 5c32ed88955..49207356e30 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -17,12 +17,19 @@ limitations under the License. 
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 5c32ed88955..49207356e30 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -17,12 +17,19 @@ limitations under the License.
 
 #include <vector>
 
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
@@ -91,16 +98,34 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const {
   return allocator_;
 }
 
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct LocalClientTestBase::EigenThreadPoolWrapper {
+  explicit EigenThreadPoolWrapper()
+      : pool(new tensorflow::thread::ThreadPool(
+            tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
+        wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
+        device(new Eigen::ThreadPoolDevice(wrapper.get(),
+                                           wrapper->NumThreads())) {}
+
+  std::unique_ptr<tensorflow::thread::ThreadPool> pool;
+  std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
 LocalClientTestBase::LocalClientTestBase(
     perftools::gputools::Platform* platform)
     : local_client_(
-          ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) {
+          ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
+      thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
   stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
                          .ValueOrDie()[local_client_->default_device_ordinal()];
   transfer_manager_ =
       TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
 }
 
+LocalClientTestBase::~LocalClientTestBase() {}
+
 std::unique_ptr<ScopedShapedBuffer>
 LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) {
   return LiteralToScopedShapedBuffer(literal,
@@ -166,55 +191,72 @@ LocalClientTestBase::ShapedBufferToScopedShapedBuffer(
   }
   *scoped_buffer->mutable_buffers() = shaped_buffer->buffers();
-  TF_CHECK_OK(
-      scoped_buffer->mutable_shape_index_to_buffer_entry()
-          ->ForEachMutableElement(
-              [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
-                               size_t* buffer_entry) -> ::tensorflow::Status {
-                if (is_leaf) {
-                  *buffer_entry =
-                      shaped_buffer->shape_index_to_buffer_entry().element(
-                          index);
-                }
-                return tensorflow::Status::OK();
-              }));
+  scoped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement(
+      [&shaped_buffer](const ShapeIndex& index, size_t* buffer_entry) {
+        if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) {
+          *buffer_entry =
+              shaped_buffer->shape_index_to_buffer_entry().element(index);
+        }
+      });
   return scoped_buffer;
 }
 
-LocalExecuteOptions LocalClientTestBase::DefaultLocalExecuteOptions() const {
-  return LocalExecuteOptions().set_allocator(
-      GetOrCreateAllocator(local_client_->platform()));
+ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
+    const {
+  return ExecutableBuildOptions();
 }
 
-std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
+  ExecutableRunOptions run_options;
+  run_options.set_inter_op_thread_pool(
+      local_client_->backend().inter_op_thread_pool());
+  run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
+  run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
+
+  return run_options;
+}
+
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocallyOrDie(
     const Computation& computation,
     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
-  return ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions());
+  return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
+                        DefaultExecutableRunOptions())
+      .ConsumeValueOrDie();
 }
 
-std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocallyOrDie(
     const Computation& computation,
     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const LocalExecuteOptions& options) {
-  return ShapedBufferToScopedShapedBuffer(
-      local_client_->ExecuteLocally(computation, arguments, options)
-          .ConsumeValueOrDie(),
-      options.allocator());
+    const ExecutableBuildOptions& build_options,
+    const ExecutableRunOptions& run_options) {
+  return ExecuteLocally(computation, arguments, build_options, run_options)
+      .ConsumeValueOrDie();
 }
 
-void LocalClientTestBase::ExecuteLocally(
+StatusOr<std::unique_ptr<ScopedShapedBuffer>>
+LocalClientTestBase::ExecuteLocally(
     const Computation& computation,
-    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    ShapedBuffer* result) {
-  ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions(), result);
+    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+  return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
+                        DefaultExecutableRunOptions());
 }
 
-void LocalClientTestBase::ExecuteLocally(
+StatusOr<std::unique_ptr<ScopedShapedBuffer>>
+LocalClientTestBase::ExecuteLocally(
     const Computation& computation,
     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const LocalExecuteOptions& options, ShapedBuffer* result) {
-  ASSERT_IS_OK(
-      local_client_->ExecuteLocally(computation, arguments, options, result));
+    const ExecutableBuildOptions& build_options,
+    const ExecutableRunOptions& run_options) {
+  std::vector<const Shape*> argument_layouts(arguments.size());
+  for (int i = 0; i < arguments.size(); ++i) {
+    argument_layouts[i] = &arguments[i]->shape();
+  }
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<LocalExecutable> executable,
+      local_client_->Compile(computation, argument_layouts, build_options));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> buffer,
+                      executable->Run(arguments, run_options));
+  return ShapedBufferToScopedShapedBuffer(std::move(buffer),
+                                          run_options.allocator());
 }
 
 }  // namespace xla
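// A sketch of the compile-then-run flow these helpers wrap, assuming a test
// fixture deriving from LocalClientTestBase (names as declared in the header
// below): compilation happens once inside ExecuteLocally, and the run options
// carry the allocator and the Eigen intra-op thread pool set up above.
ComputationBuilder builder(local_client_, "add_one");
builder.Add(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"),
            builder.ConstantR0<float>(1.0f));
Computation computation = builder.Build().ConsumeValueOrDie();

std::unique_ptr<ScopedShapedBuffer> arg =
    LiteralToScopedShapedBuffer(*LiteralUtil::CreateR0<float>(41.0f));
std::unique_ptr<ScopedShapedBuffer> result =
    ExecuteLocallyOrDie(computation, {arg.get()});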
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 62916d50e3c..e3c3bb46cf2 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -74,8 +74,10 @@ class TestAllocator : public StreamExecutorMemoryAllocator {
 // A base class for tests which exercise the LocalClient interface.
 class LocalClientTestBase : public ::testing::Test {
  protected:
+  struct EigenThreadPoolWrapper;
   explicit LocalClientTestBase(
       perftools::gputools::Platform* platform = nullptr);
+  virtual ~LocalClientTestBase();
 
   static TestAllocator* GetOrCreateAllocator(
       perftools::gputools::Platform* platform);
@@ -99,27 +101,30 @@ class LocalClientTestBase : public ::testing::Test {
   // Execute the given computation on the local client. With and without
   // options.
-  std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+  StatusOr<std::unique_ptr<ScopedShapedBuffer>> ExecuteLocally(
       const Computation& computation,
       tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
-  std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+  StatusOr<std::unique_ptr<ScopedShapedBuffer>> ExecuteLocally(
       const Computation& computation,
       tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      const LocalExecuteOptions& options);
+      const ExecutableBuildOptions& build_options,
+      const ExecutableRunOptions& run_options);
+
+  std::unique_ptr<ScopedShapedBuffer> ExecuteLocallyOrDie(
+      const Computation& computation,
+      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+  std::unique_ptr<ScopedShapedBuffer> ExecuteLocallyOrDie(
+      const Computation& computation,
+      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+      const ExecutableBuildOptions& build_options,
+      const ExecutableRunOptions& run_options);
+
+  // Returns a default set of execute options.
+  ExecutableBuildOptions DefaultExecutableBuildOptions() const;
 
   // Returns a default set of execute options, configured to use allocator_
   // as the allocator.
-  LocalExecuteOptions DefaultLocalExecuteOptions() const;
-
-  // Overloads which write result into the given buffer.
-  void ExecuteLocally(
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      ShapedBuffer* result);
-  void ExecuteLocally(
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      const LocalExecuteOptions& options, ShapedBuffer* result);
+  ExecutableRunOptions DefaultExecutableRunOptions() const;
 
   // Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are
   // deallocated when the object is destructed.
@@ -139,6 +144,8 @@ class LocalClientTestBase : public ::testing::Test {
 
   TransferManager* transfer_manager_;
   LocalClient* local_client_;
+
+  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc
index b520d89de3c..796f43ea4ed 100644
--- a/tensorflow/compiler/xla/tests/log_test.cc
+++ b/tensorflow/compiler/xla/tests/log_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -47,6 +48,7 @@ TEST_F(LogTest, LogTenValues) {
   builder.Log(x);
 
   std::vector<float> expected;
+  expected.reserve(input.size());
   for (float f : input) {
     expected.push_back(std::log(f));
   }
@@ -59,6 +61,7 @@ TEST_F(LogTest, LogTenValues) {
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 014417a2057..e4dbd6864a3 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -22,18 +22,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -42,8 +42,10 @@ namespace { class MapTest : public ClientLibraryTestBase { public: explicit MapTest(perftools::gputools::Platform* platform = nullptr) - : ClientLibraryTestBase(platform, - /*disabled_pass_names=*/{"algsimp", "inline"}) {} + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } // Creates a function that adds its scalar argument with the constant 1.0. // @@ -100,8 +102,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a function that adds its scalar argument with the constant 1.0 and // then multiplies by the original element. // - // /---------------\ - // / \ + // /------------------| + // / | // x {R0F32} ----> (add) ----> (mul) // / // 1.0f ---------/ @@ -147,8 +149,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a function that adds three scalar arguments // - // x {R0F32} ----\ - // \ + // x {R0F32} -------| + // | // y {R0F32} ----> (add) ---> (add) // / // z {R0F32} ---------------/ @@ -529,9 +531,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_MATCH(computation_status.status().ToString(), - testing::HasSubstr("error from: ErrorAdd: binary op with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: binary op with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all @@ -568,12 +570,60 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { ErrorSpec(0.01f)); } +// Regression test for b/35786417, where the inliner would not notice the change +// of parameter order inside the map. 
+TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
+  ComputationBuilder builder(client_, TestName());
+
+  auto sub_builder = builder.CreateSubBuilder("power");
+  auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+  auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+  sub_builder->Sub(y, x);  // note that this is y - x, not x - y
+  auto sub_opposite = sub_builder->BuildAndNoteError();
+
+  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+  std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+  std::unique_ptr<GlobalData> param0_data =
+      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> param1_data =
+      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+  auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+  auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+  builder.Map({param0, param1}, sub_opposite);
+
+  ComputeAndCompareR0<float>(
+      &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
+}
+
+// Regression test for b/35786417, where the inliner would CHECK-fail due to
+// the mul inside the map having more parameters than the map does.
+TEST_F(MapTestWithFullOpt, MapSquare) {
+  ComputationBuilder builder(client_, TestName());
+
+  auto sub_builder = builder.CreateSubBuilder("power");
+  auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+  sub_builder->Mul(x, x);
+  auto square = sub_builder->BuildAndNoteError();
+
+  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+  std::unique_ptr<GlobalData> param0_data =
+      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+  auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+  builder.Map({param0}, square);
+
+  ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
+                             ErrorSpec(0.01f));
+}
+
 }  // namespace
 }  // namespace xla
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
   if (!parse_result) {
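// What the two regression tests above pin down, in one line each: after
// inlining, Map({param0, param1}, f) with f(x, y) = y - x must still compute
// param1 - param0 (5 - 2 == 3, not 2 - 5 == -3), and a subcomputation that
// uses its single parameter twice (x * x) must not trip the inliner's
// parameter-count check.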
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 8aa40294406..51261f0ac1c 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
@@ -158,11 +159,65 @@ TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); }
 
 TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
 
+class MatOpsDotAddTest
+    : public ClientLibraryTestBase,
+      public ::testing::WithParamInterface<std::tuple<bool, bool>> {};
+
+TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) {
+  bool row_major = std::get<0>(GetParam());
+  bool add_lhs = std::get<1>(GetParam());
+  Array2D<float> lhs({{1.0, 2.0}, {3.0, 4.0}});
+  Array2D<float> rhs({{10.0, 11.0}, {12.0, 13.0}});
+
+  auto minor_to_major = [](bool row_major) -> std::vector<int64> {
+    return {row_major ? 1 : 0, row_major ? 0 : 1};
+  };
+
+  auto prim_type = primitive_util::NativeToPrimitiveType<float>();
+  Shape lhs_shape =
+      ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()});
+  Shape rhs_shape =
+      ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()});
+
+  TF_ASSIGN_OR_ASSERT_OK(
+      auto lhs_handle,
+      client_->TransferToServer(
+          *LiteralUtil::CreateR2FromArray2DWithLayout(
+              lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+  TF_ASSIGN_OR_ASSERT_OK(
+      auto rhs_handle,
+      client_->TransferToServer(
+          *LiteralUtil::CreateR2FromArray2DWithLayout(
+              rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+
+  ComputationBuilder builder(client_, TestName());
+  auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
+  auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs");
+  auto result = builder.Dot(lhs_arg, rhs_arg);
+  Array2D<float> expected;
+  if (add_lhs) {
+    result = builder.Add(result, lhs_arg);
+    expected = Array2D<float>({{35, 39}, {81, 89}});
+  } else {
+    result = builder.Add(result, rhs_arg);
+    expected = Array2D<float>({{44, 48}, {90, 98}});
+  }
+
+  ComputeAndCompareR2<float>(&builder, expected,
+                             {lhs_handle.get(), rhs_handle.get()},
+                             ErrorSpec(1e-6));
+}
+
+INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest,
+                        ::testing::Combine(::testing::Bool(),
+                                           ::testing::Bool()));
+
 }  // namespace
 }  // namespace xla
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index 2cd680399b3..4929e25c580 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -36,7 +37,7 @@ XLA_TEST_F(SliceTest, Slice2D) { ComputationBuilder builder(client_, "slice_2d"); auto original = builder.ConstantR2( {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); - builder.Slice(original, {2, 1}, {4, 3}); + builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); Array2D expected({{8.0f, 9.0f}, {11.0f, 12.0f}}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -47,7 +48,7 @@ XLA_TEST_F(SliceTest, Slice3D) { Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); auto original = builder.ConstantR3FromArray3D(array_3d); - builder.Slice(original, {0, 0, 1}, {2, 1, 2}); + builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); Array3D expected_3d({{{2.0f}}, {{6.0f}}}); ComputeAndCompareR3(&builder, expected_3d, {}, ErrorSpec(0.000001)); @@ -58,6 +59,7 @@ XLA_TEST_F(SliceTest, Slice3D) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index f044c94b8d0..4922bbf21c4 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -468,6 +469,7 @@ XLA_TEST_F(PadTest, ReducePad) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2f05576ceeb..3e1bfcd3090 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -163,7 +164,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { auto computation = builder.Build().ConsumeValueOrDie(); auto execute_status = client_->Execute(computation, {data.get(), data.get()}, - /*output_layout=*/nullptr, + /*execution_options=*/nullptr, /*execution_profile=*/nullptr); ASSERT_EQ(execute_status.status().code(), tensorflow::error::FAILED_PRECONDITION); @@ -246,6 +247,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } std::vector param_data; + param_data.reserve(param_data_owner.size()); for (const std::unique_ptr& data : param_data_owner) { param_data.push_back(data.get()); } @@ -326,7 +328,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { ComputationBuilder builder(client_, TestName()); auto input = builder.Parameter(0, original, "input"); // Use the slice operator to get an off-diagonal element. - builder.Slice(input, {0, 1}, {1, 2}); + builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -341,6 +343,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 96393c41e80..b031725d8ab 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -18,9 +18,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -94,11 +97,51 @@ TEST_F(PredTest, ConstantR2Pred) { EXPECT_EQ(expected, ExecuteToString(&builder, {})); } +TEST_F(PredTest, AnyR1True) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false}); + TF_ASSERT_OK(Any(a, &builder).status()); + ComputeAndCompareR0(&builder, true, {}); +} + +TEST_F(PredTest, AnyR1False) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({false, false}); + TF_ASSERT_OK(Any(a, &builder).status()); + ComputeAndCompareR0(&builder, false, {}); +} + +TEST_F(PredTest, AnyR1VacuouslyFalse) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + TF_ASSERT_OK(Any(a, &builder).status()); + ComputeAndCompareR0(&builder, false, {}); +} + +TEST_F(PredTest, AnyR2True) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({ + {false, false, false}, {false, false, false}, {false, false, true}, + }); + TF_ASSERT_OK(Any(a, &builder).status()); + ComputeAndCompareR0(&builder, true, {}); +} + +TEST_F(PredTest, AnyR2False) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({ + {false, false, false}, {false, false, false}, {false, false, false}, + }); + TF_ASSERT_OK(Any(a, &builder).status()); + ComputeAndCompareR0(&builder, false, {}); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1b4b170dfd4..5117478bfd5 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -18,9 +18,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -38,6 +40,12 @@ class PrngTest : public ClientLibraryTestBase { template void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); + + // Computes the χ² statistic of a sample of the discrete uniform distribution + // of the given range size. `expected_count` is the number of times each + // possible value is expected to be generated. 
+  // Thus, the sample size is `range_size * expected_count`.
+  double UniformChiSquared(int32 range_size, int32 expected_count);
 };
 
 template <typename T>
@@ -47,8 +55,9 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
       builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
 
+  SetSeed(42);
   auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
-  EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
+  EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
   LiteralUtil::EachCell<T>(
       *actual, [=](tensorflow::gtl::ArraySlice<int64>, T value) {
         EXPECT_LE(a, value);
@@ -68,7 +77,7 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
       auto actual,
       client_->ExecuteAndTransfer(computation, /*arguments=*/{},
                                   &execution_options));
-  EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
+  EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
   int32 sum = 0;
   LiteralUtil::EachCell<uint32>(
      *actual, [&sum](tensorflow::gtl::ArraySlice<int64>, uint32 value) {
@@ -97,6 +106,57 @@ XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest<float>(0, 1, {0, 20}); }
 XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
 XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
 
+namespace {
+template <typename T>
+T Square(T x) {
+  return x * x;
+}
+}  // namespace
+
+double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) {
+  int32 sample_size = range_size * expected_count;
+
+  ComputationBuilder builder(client_, TestName());
+  builder.RngUniform(builder.ConstantR0<int32>(0),
+                     builder.ConstantR0<int32>(range_size),
+                     ShapeUtil::MakeShape(S32, {sample_size}));
+
+  SetSeed(42);
+  auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+  std::vector<int32> counts(range_size, 0);
+  LiteralUtil::EachCell<int32>(
+      *actual, [&counts](tensorflow::gtl::ArraySlice<int64>, int32 value) {
+        ++counts[value];
+      });
+  int64 sum = 0;
+  for (int32 i = 0; i < range_size; ++i) {
+    sum += Square(static_cast<int64>(counts[i] - expected_count));
+  }
+  return static_cast<double>(sum) / expected_count;
+}
+
+// We only test distribution of uniform discrete PRNG as other types are based
+// on it.
+// These range sizes are arbitrary but include prime numbers, powers of 2, and
+// other composite numbers.
+// The level of significance in all these cases is 1/20.
+// TODO(b/35723038): Use parametrized tests where possible.
+XLA_TEST_F(PrngTest, Uniformity7) {
+  EXPECT_LT(UniformChiSquared(7, 256), 12.5916);
+}
+XLA_TEST_F(PrngTest, Uniformity61) {
+  EXPECT_LT(UniformChiSquared(61, 256), 79.0819);
+}
+XLA_TEST_F(PrngTest, Uniformity64) {
+  EXPECT_LT(UniformChiSquared(64, 256), 82.5287);
+}
+XLA_TEST_F(PrngTest, Uniformity108) {
+  EXPECT_LT(UniformChiSquared(108, 256), 132.144);
+}
+XLA_TEST_F(PrngTest, Uniformity256) {
+  EXPECT_LT(UniformChiSquared(256, 256), 293.248);
+}
+
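// Where the thresholds above come from: under the null hypothesis the
// statistic is approximately χ²-distributed with range_size - 1 degrees of
// freedom, and each bound is the 95th percentile of that distribution
// (df = 6 gives 12.5916 for Uniformity7, df = 255 gives 293.248 for
// Uniformity256), so each test rejects a fair generator with probability
// about 1/20, matching the stated level of significance.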
 XLA_TEST_F(PrngTest, MapUsingRng) {
   // Build a x -> (x + U[0,1)) computation.
   auto build_sum_rng = [this](ComputationBuilder& builder) {
@@ -135,7 +195,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
   }
 }
 
-// This tests demonstrates the global seeding behaviour.
+// This test demonstrates the global seeding behavior.
 // * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is
 //   fixed (i.e., there is a single output for a given seed);
 // * If no seed is passed in then the output of every call can be different;
 // *
@@ -208,6 +268,7 @@ XLA_TEST_F(PrngTest, TenValuesN01) {
   builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
                     ShapeUtil::MakeShape(F32, {10}));
+  SetSeed(42);
   ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
   // TODO(b/25995601): Test that resultant values are reasonable
 }
@@ -217,6 +278,7 @@
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
index eb7e63705b2..4a02567a1a2 100644
--- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
@@ -46,6 +47,7 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) {
 
 int main(int argc, char** argv) {
   std::vector<tensorflow::Flag> flag_list;
   xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+  xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
   if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index f3d8da5c8c8..ff24177520e 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -61,7 +62,7 @@ namespace {
 class ReduceTest : public ClientLibraryTestBase {
  protected:
   ReduceTest() {
-    // Implementation note: layed out z >> y >> x by default.
+    // Implementation note: laid out z >> y >> x by default.
     // clang-format off
     literal_2d_ = LiteralUtil::CreateR2<float>({
       // x0   x1   x2
@@ -109,6 +110,41 @@ class ReduceTest : public ClientLibraryTestBase {
                                ErrorSpec(0.001));
   }
 
+  void RunR1ToR0PredTest(bool and_reduce,
+                         tensorflow::gtl::ArraySlice<int32> input_data) {
+    const int element_count = input_data.size();
+    ComputationBuilder builder(client_, TestName());
+    const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
+    auto input_par = builder.Parameter(0, input_shape, "input");
+    auto pred_values =
+        builder.Eq(input_par, builder.ConstantR1<int32>(element_count, 1));
+    ComputationDataHandle init_value;
+    Computation reduce;
+    if (and_reduce) {
+      init_value = builder.ConstantR0<bool>(true);
+      reduce = CreateScalarLogicalAndComputation(&builder);
+    } else {
+      init_value = builder.ConstantR0<bool>(false);
+      reduce = CreateScalarLogicalOrComputation(&builder);
+    }
+    builder.Reduce(pred_values, init_value, reduce,
+                   /*dimensions_to_reduce=*/{0});
+
+    std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+    std::unique_ptr<GlobalData> input_global_data =
+        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+    bool expected = and_reduce;
+    for (bool item : input_data) {
+      if (and_reduce) {
+        expected = expected && item;
+      } else {
+        expected = expected || item;
+      }
+    }
+    ComputeAndCompareR0<bool>(&builder, expected, {input_global_data.get()});
+  }
+
   // Runs an R2 => R0 reduction test with the given number of (rows, cols).
   void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
     ComputationBuilder builder(client_, TestName());
@@ -176,9 +212,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
 XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
 XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
 XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
-XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
 XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
 XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
+XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
 XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
 XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
 XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
@@ -186,6 +222,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
 XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
   RunR1ToR0Test(16 * 1024 + 1);
 }
+XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) { RunR1ToR0Test(64 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) { RunR1ToR0Test(1024 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) { RunR1ToR0Test(4096 * 4096); }
 
 XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
 XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }
@@ -219,6 +258,40 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
 XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); }
 XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); }
 
+// TODO(b/34969189): Invalid CAS generated on GPU.
+XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) {
+  constexpr int element_count = 10;
+  std::vector<int32> input(element_count, 1);
+  RunR1ToR0PredTest(/*and_reduce=*/true, input);
+}
+
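// A worked view of RunR1ToR0PredTest: the int32 inputs are compared against 1
// to form a PRED vector, which the scalar logical computation then folds.
// For the all-ones input above, and_reduce gives
//   true && (1 == 1) && ... && (1 == 1) == true,
// while for a mixed {1, 0, 1, ...} input it turns false at the first zero;
// or_reduce is the dual, starting from false.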
+// TODO(b/34969189): Invalid CAS generated on GPU.
+XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) {
+  constexpr int element_count = 10;
+  std::vector<int32> input(element_count);
+  for (int i = 0; i < element_count; ++i) {
+    input[i] = i % 2;
+  }
+  RunR1ToR0PredTest(/*and_reduce=*/true, input);
+}
+
+// TODO(b/34969189): Invalid CAS generated on GPU.
+XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) {
+  constexpr int element_count = 10;
+  std::vector<int32> input(element_count, 1);
+  RunR1ToR0PredTest(/*and_reduce=*/false, input);
+}
+
+// TODO(b/34969189): Invalid CAS generated on GPU.
+XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) {
+  constexpr int element_count = 10;
+  std::vector<int32> input(element_count);
+  for (int i = 0; i < element_count; ++i) {
+    input[i] = i % 2;
+  }
+  RunR1ToR0PredTest(/*and_reduce=*/false, input);
+}
+
 XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
   const int64 rows = 111, cols = 50;
 
@@ -251,6 +324,72 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
                              ErrorSpec(0.01, 1e-4));
 }
 
+XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
+  const int64 rows = 111, cols = 50;
+
+  ComputationBuilder builder(client_, TestName());
+  Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto zero = builder.ConstantR0<float>(0.0);
+  auto log_ = builder.Log(input);
+  auto transpose = builder.Transpose(log_, {1, 0});
+  builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
+
+  Array2D<float> input_data(rows, cols);
+  input_data.FillRandom(3.14f, 0.04);
+  std::unique_ptr<Literal> input_literal =
+      LiteralUtil::CreateR2FromArray2D(input_data);
+  input_literal =
+      LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
+  std::unique_ptr<GlobalData> input_global_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+  std::vector<float> expected;
+  for (int64 colno = 0; colno < cols; ++colno) {
+    float column_sum = 0;
+    for (int64 rowno = 0; rowno < rows; ++rowno) {
+      column_sum += log(input_data(rowno, colno));
+    }
+    expected.push_back(column_sum);
+  }
+  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+                             ErrorSpec(0.01, 1e-4));
+}
+
+XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
+  const int64 rows = 111, cols = 50;
+
+  ComputationBuilder builder(client_, TestName());
+  Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto zero = builder.ConstantR0<float>(0.0);
+  auto log_ = builder.Log(input);
+  auto reshape = builder.Reshape(log_, {rows, cols});
+  builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+  Array3D<float> input_data(rows, 2, cols / 2);
+  input_data.FillRandom(3.14f, 0.04);
+  std::unique_ptr<Literal> input_literal =
+      LiteralUtil::CreateR3FromArray3D(input_data);
+  std::unique_ptr<GlobalData> input_global_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+  std::vector<float> expected;
+  for (int64 major = 0; major < 2; ++major) {
+    for (int64 colno = 0; colno < cols / 2; ++colno) {
+      float column_sum = 0;
+      for (int64 rowno = 0; rowno < rows; ++rowno) {
+        column_sum += log(input_data(rowno, major, colno));
+      }
+      expected.push_back(column_sum);
+    }
+  }
+  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+                             ErrorSpec(0.01, 1e-4));
+}
+
 struct BoundsLayout {
   std::vector<int64> bounds;
std::vector layout; @@ -490,6 +629,7 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 149a75c8e10..ec7b47bc283 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -43,7 +44,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { public: ReduceWindowTest() : builder_(client_, TestName()) {} - void ReduceWindowAdd(ComputationDataHandle input, + void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -52,7 +53,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { window_dimensions, window_strides, padding); } - void ReduceWindowMax(ComputationDataHandle input, + void ReduceWindowMax(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -61,7 +62,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { CreateScalarMax(), window_dimensions, window_strides, padding); } - void ReduceWindowMin(ComputationDataHandle input, + void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -182,6 +183,7 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); } + // TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { Array4D input_array(2, 2, 4, 16); @@ -368,6 +370,16 @@ TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); } +TEST_F(ReduceWindowTest, Add1x2In2x2Same) { + Array2D input_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto input = builder_.ConstantR2FromArray2D(input_array); + ReduceWindowAdd(input, {1, 2}, {1, 1}, Padding::kSame); + Array2D expected({ + {3.0f, 2.0f}, {7.0f, 4.0f}, + }); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; @@ -446,10 +458,16 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); - Array4D expected(1, 2, 1, 1); - expected(0, 0, 0, 0) = 6; - expected(0, 1, 0, 0) = 8; - ComputeAndCompareR4(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3)); + const auto reduce_func = [](float arg1, float arg2) { + return std::min(arg1 + arg2, 8.0f); + }; + + auto expected = + ReferenceUtil::ReduceWindow4DGeneric(input_array, 3.0f, reduce_func, + /*window=*/{1, 1, 2, 1}, + /*stride=*/{1, 1, 1, 1}, padding); + + ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); } } // namespace @@ -457,6 +475,7 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 802087b5086..7c6700feef8 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" @@ -152,6 +153,7 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index ce309eb7439..c9817bc23d8 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -61,6 +62,7 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 18e6e2d3f1d..ae7d07727b1 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -26,18 +26,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -68,6 +68,22 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + a = builder.Neg(a); + auto reshape = + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + + ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); @@ -76,6 +92,24 @@ XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 +// with an incorrect result rank. 
+XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial3x0) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); @@ -383,15 +417,15 @@ XLA_TEST_F(ReshapeTest, ToScalar) { XLA_TEST_F(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); b.Reshape(b.ConstantR1({1}), {}, {}); - EXPECT_MATCH(ExecuteToString(&b, {}), - testing::HasSubstr("dimensions not a permutation")); + EXPECT_THAT(ExecuteToString(&b, {}), + ::testing::HasSubstr("dimensions not a permutation")); } XLA_TEST_F(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); b.Reshape(b.ConstantR1({1, 2}), {1}, {}); - EXPECT_MATCH(ExecuteToString(&b, {}), - testing::HasSubstr("mismatched element counts")); + EXPECT_THAT(ExecuteToString(&b, {}), + ::testing::HasSubstr("mismatched element counts")); } XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { @@ -796,6 +830,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 63dd4421fad..5ca9702380f 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -157,6 +158,7 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 5b734c0f400..05ce22fc359 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -144,6 +145,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 04a8bab0eb8..f0760241cdb 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -148,6 +148,7 @@ TEST_F(RoundTripTransferTest, R4F32_Large) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 4d68ba46211..47a39ffbbc4 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -30,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -245,37 +247,183 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); - builder.Div(builder.ConstantR0(-5), builder.ConstantR0(2)); +struct DivS32Params { + int32 dividend; + int32 divisor; + int32 quotient; + int32 remainder; +}; - ComputeAndCompareR0(&builder, -2, {}); +void PrintTo(const DivS32Params& p, std::ostream* os) { + *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", " + << p.remainder << "}"; } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); +class DivS32Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; - ComputeAndCompareR0(&builder, -1, {}); +XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); + + ComputeAndCompareR0(&builder, p.quotient, {}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) { +XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { + DivS32Params p = GetParam(); ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(7919)); + builder.Rem(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); - ComputeAndCompareR0(&builder, -1309, {}); + ComputeAndCompareR0(&builder, p.remainder, {}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) { +XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { + DivS32Params p = GetParam(); ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(INT_MAX)); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividendd = + CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); + auto divisord = + CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); + builder.Div(dividend, divisor); - ComputeAndCompareR0(&builder, -1, {}); + ComputeAndCompareR0(&builder, p.quotient, + {dividendd.get(), divisord.get()}); } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) { +XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividendd = + CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); + auto divisord = + CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); + builder.Rem(dividend, divisor); + + ComputeAndCompareR0(&builder, p.remainder, + {dividendd.get(), divisord.get()}); +} + +INSTANTIATE_TEST_CASE_P( + DivS32Test_Instantiation, DivS32Test, + ::testing::Values( + // Positive divisors. + DivS32Params{5, 2, 2, 1}, // + DivS32Params{-5, 2, -2, -1}, // + DivS32Params{17, 3, 5, 2}, // + DivS32Params{-17, 3, -5, -2}, // + // Negative divisors. 
+ DivS32Params{5, -2, -2, 1}, // + DivS32Params{-5, -2, 2, -1}, // + DivS32Params{17, -3, -5, 2}, // + DivS32Params{-17, -3, 5, -2}, // + // Large positive divisors. + DivS32Params{INT32_MIN, 7919, -271181, -1309}, // + DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, // + DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, // + DivS32Params{INT32_MIN, 0x40000000, -2, 0}, // + DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, // + // Large negative divisors. + DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, // + DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, // + DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, // + DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, // + DivS32Params{INT32_MIN, -0x40000000, 2, 0}, // + DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff})); + +TEST_F(ScalarComputationsTest, DivU32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; + // clang-format on + + Computation div_computation; + { + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle dividend = + builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + ComputationDataHandle divisor = + builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); + builder.Div(dividend, divisor); + TF_ASSIGN_OR_ASSERT_OK(div_computation, builder.Build()); + } + + for (uint32 divisor : vals) { + if (divisor != 0) { + for (uint32 dividend : vals) { + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); + TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, + client_->TransferToServer(*divisor_literal)); + auto actual_literal = + client_ + ->ExecuteAndTransfer(div_computation, + {dividend_data.get(), divisor_data.get()}, + &execution_options_) + .ConsumeValueOrDie(); + auto expected_literal = + LiteralUtil::CreateR0(dividend / divisor); + LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + } + } + } +} + +TEST_F(ScalarComputationsTest, RemU32s) { + // clang-format off + // Some interesting values to test. 
+ std::vector vals = { + 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; + // clang-format on + + Computation rem_computation; + { + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle dividend = + builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + ComputationDataHandle divisor = + builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); + builder.Rem(dividend, divisor); + TF_ASSIGN_OR_ASSERT_OK(rem_computation, builder.Build()); + } + + for (uint32 divisor : vals) { + if (divisor != 0) { + for (uint32 dividend : vals) { + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); + TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, + client_->TransferToServer(*divisor_literal)); + auto actual_literal = + client_ + ->ExecuteAndTransfer(rem_computation, + {dividend_data.get(), divisor_data.get()}, + &execution_options_) + .ConsumeValueOrDie(); + auto expected_literal = + LiteralUtil::CreateR0(dividend % divisor); + LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + } + } + } +} + +TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { ComputationBuilder builder(client_, TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); @@ -295,6 +443,13 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { ComputeAndCompareR0(&builder, 0x7FFFFFFF, {}); } +XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); + + ComputeAndCompareR0(&builder, 2, {}); +} + TEST_F(ScalarComputationsTest, LogicalAnd) { for (bool x : {false, true}) { for (bool y : {false, true}) { @@ -626,6 +781,7 @@ TEST_F(ScalarComputationsTest, SqrtF320) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index fb1effc8c46..36110da2478 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
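The table's edge rows, and the exhaustive `DivU32s`/`RemU32s` loops above, are consistent with ordinary C++ integer arithmetic, which truncates toward zero and gives the remainder the dividend's sign. A standalone check of a few signed rows and of the unsigned division identity:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Signed rows from the DivS32 table: quotient truncates toward zero and
  // the remainder carries the dividend's sign, so (a / b) * b + a % b == a.
  const int32_t int32_min = INT32_MIN;
  assert(int32_min / 7919 == -271181 && int32_min % 7919 == -1309);
  assert(int32_min / INT32_MAX == -1 && int32_min % INT32_MAX == -1);
  assert((int32_min + 1) / INT32_MAX == -1 && (int32_min + 1) % INT32_MAX == 0);

  // The unsigned loops in DivU32s/RemU32s pair every interesting dividend
  // with every nonzero divisor; the division identity must hold throughout.
  const std::vector<uint32_t> vals = {0, 1, 2, 17, 101, 3333, 0x7FFFFFFFu,
                                      0x80000000u, UINT32_MAX - 1, UINT32_MAX};
  for (uint32_t divisor : vals) {
    if (divisor == 0) continue;  // divide-by-zero is skipped, as in the test
    for (uint32_t dividend : vals) {
      assert((dividend / divisor) * divisor + dividend % divisor == dividend);
    }
  }
  return 0;
}
```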
#include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -379,6 +380,7 @@ XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 5ec9ac95fae..5eb4fee8ed2 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -260,6 +261,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc index e15d744d953..25bb915be56 100644 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/core/status.h" @@ -100,6 +101,7 @@ TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index d63582fb98a..70345c300cc 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -37,13 +38,14 @@ class SliceTest : public ClientLibraryTestBase { template void RunSliceTenToTwo() { std::vector constant; + constant.reserve(10); for (int i = 0; i < 10; ++i) { constant.push_back(static_cast(i)); } ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1(constant); - builder.Slice(original, {2}, {4}); + builder.Slice(original, {2}, {4}, {1}); const std::vector expected = {static_cast(2), static_cast(3)}; @@ -54,7 +56,7 @@ class SliceTest : public ClientLibraryTestBase { XLA_TEST_F(SliceTest, SliceZeroToZeroF32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1({}); - builder.Slice(original, {0}, {0}); + builder.Slice(original, {0}, {0}, {1}); ComputeAndCompareR1(&builder, {}, {}); } @@ -63,7 +65,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) { ComputationBuilder builder(client_, TestName()); std::vector constant(10, 0.3); auto original = builder.ConstantR1(constant); - builder.Slice(original, {7}, {7}); + builder.Slice(original, {7}, {7}, {1}); ComputeAndCompareR1(&builder, {}, {}); } @@ -86,7 +88,7 @@ TEST_F(SliceTest, SliceTenToTen) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1(values); - builder.Slice(original, {0}, {10}); + builder.Slice(original, {0}, {10}, {1}); ComputeAndCompareR1(&builder, values, {}, ErrorSpec(0.000001)); } @@ -97,7 +99,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1(values); - builder.Slice(original, {1024 - 4}, {1024}); + builder.Slice(original, {1024 - 4}, {1024}, {1}); const std::vector expected = {1020, 1021, 1022, 1023}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.000001)); @@ -111,7 +113,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1(values); - builder.Slice(original, {7}, {7 + 1024}); + builder.Slice(original, {7}, {7 + 1024}, {1}); std::vector expected(1024); std::iota(values.begin(), values.end(), 7.0); @@ -121,7 +123,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.Slice(original, {0, 0}, {0, 0}); + builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}); } @@ -129,7 +131,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); - builder.Slice(original, {0, 15}, {0, 20}); + builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 5), {}); } @@ -137,7 +139,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); - builder.Slice(original, {1, 0}, {3, 0}); + 
builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(2, 0), {}); } @@ -152,7 +154,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {128, 128}, {256, 256}); + builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); Array2D expected(128, 128); for (int row = 0; row < 128; ++row) { @@ -170,7 +172,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 3072}, {1, 4096}); + builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); Array2D expected(1, 1024); std::iota(expected.data(), expected.data() + 1024, 3072.0); @@ -191,7 +193,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 0}, {16, 2}); + builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); } @@ -203,7 +205,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}); ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}); + builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } @@ -212,6 +214,7 @@ struct R2Spec { int64 input_dim1; std::array slice_starts; std::array slice_limits; + std::array slice_strides; Layout layout; }; @@ -227,7 +230,7 @@ TEST_P(SliceR2Test, DoIt) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(input); - builder.Slice(a, spec.slice_starts, spec.slice_limits); + builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); std::unique_ptr> expected = ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits); @@ -238,19 +241,23 @@ TEST_P(SliceR2Test, DoIt) { INSTANTIATE_TEST_CASE_P( SliceR2TestInstantiation, SliceR2Test, ::testing::Values( - R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})}, - R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})}, - R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {256, 400, {{0, 300}}, {{256, 400}}, + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {500, 400, {{111, 123}}, {{300, 257}}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, + LayoutUtil::MakeLayout({0, 1})}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {500, 400, {{111, 123}}, {{300, 400}}, + R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {384, 512, {{128, 256}}, {{256, 384}}, + R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {357, 512, {{111, 256}}, {{301, 384}}, + R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, + LayoutUtil::MakeLayout({1, 0})}, + R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, LayoutUtil::MakeLayout({1, 0})} ) ); @@ 
-261,6 +268,7 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 79f251bbc48..e4951c42010 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -187,6 +188,7 @@ TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index cea9316a6d6..6309e712973 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -399,6 +400,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index fdbaa0d1786..61110d5b4cd 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
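The slice_test.cc hunks above thread a new strides argument through every `Slice` call; appending `{1}` (or `{1, 1}`, and so on, one entry per dimension) preserves the old dense behavior, which is why the change is mechanical. A hypothetical rank-1 reference implementation, under the assumption that output element i is `input[start + i * stride]` for `start + i * stride < limit`:

```c++
#include <cassert>
#include <vector>

// Reference semantics for a strided rank-1 slice (sketch, not the XLA API).
std::vector<float> SliceR1(const std::vector<float>& input, int start,
                           int limit, int stride) {
  std::vector<float> out;
  for (int i = start; i < limit; i += stride) out.push_back(input[i]);
  return out;
}

int main() {
  const std::vector<float> v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  // Stride 1 matches the old two-argument Slice: elements of [2, 4).
  assert((SliceR1(v, 2, 4, 1) == std::vector<float>{2, 3}));
  // Stride 2 keeps every other element of [0, 10).
  assert((SliceR1(v, 0, 10, 2) == std::vector<float>{0, 2, 4, 6, 8}));
  return 0;
}
```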
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -163,6 +164,7 @@ TEST_F(UnaryOpTest, SignAbsTestR2) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 7f3d7d9cb4c..26a08953b15 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -219,6 +220,7 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index d9fc1e1e8f5..efde45375fd 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -41,8 +42,10 @@ namespace { class VecOpsSimpleTest : public ClientLibraryTestBase { public: explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) - : ClientLibraryTestBase(platform, - /*disabled_pass_names=*/{"algsimp", "inline"}) {} + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } ErrorSpec error_spec_{0.0001}; }; @@ -64,6 +67,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { ComputationBuilder builder(client_, TestName()); std::vector exponents; + exponents.reserve(count); for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } @@ -71,6 +75,7 @@ TEST_F(VecOpsSimpleTest, ExpManyValues) { auto exp = builder.Exp(x); std::vector expected; + expected.reserve(exponents.size()); for (float exponent : exponents) { expected.push_back(std::exp(exponent)); } @@ -155,6 +160,35 @@ TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({0.0, -0.0}); + auto exp = builder.SqrtF32(x); + + ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); +} + +XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + auto exp = builder.SqrtF32(x); + + std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { + ComputationBuilder builder(client_, TestName()); + auto x = + builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); + auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); + + std::vector expected = {.25, 1, .03125, 2.5, + 2.23607, .009000, .900025}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { ComputationBuilder builder(client_, TestName()); auto add = CreateScalarAddComputation(F32, &builder); @@ -408,6 +442,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index e6bbed671ff..5f917797744 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -23,9 +23,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -247,6 +249,291 @@ TEST_F(WhileTest, WhileWithTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +// Tests two while nodes when the result type T is a Tuple and the second +// while node uses the result of the first while node which is used in two +// nodes. +// tuple> w0(0, vector(10, 0.0f)); +// w0 = while (get<0>(w0) < c1) { +// get<0>(w0) = get<0>(w0) + 1; +// get<1>(w0) = get<1>(w0) + vector(10, 1.0f); +// } +// tuple> w1(get<0>(w0), get<1>(w0)); +// w1 = while (get<0>(w1) < c2) { +// get<0>(w1) = get<0>(w1) + 1; +// get<1>(w1) = get<1>(w1) + vector(10, 1.0f); +// } +// result = get<1>(w0) + get<1>(w1) +TEST_F(WhileTest, TwoWhileWithTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + Computation body2; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body2, builder.Build()); + } + + // Create a While node with computations for the condition and the body. 
+ ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + + auto while2 = builder.While(condition2, body2, while1); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Test while nodes that share the while body computation. +TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + + auto while2 = builder.While(condition2, body, while1); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Test while nodes that share the while body computation. +// TODO(b/37245345): Fails on GPU backend. 
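A useful detail of `TwoWhileWithTupleResult` above: the second loop starts from the first loop's final tuple, so its counter resumes at c1 and its payload ends at c2 per element, not c2 - c1. A standalone host model reproducing the expected `c1 + c2` sum:

```c++
#include <array>
#include <cassert>

int main() {
  // Host model of TwoWhileWithTupleResult: loop 2 resumes from loop 1's
  // state (iteration = c1), so get<1>(w0) + get<1>(w1) is a vector of c1 + c2.
  const int c1 = 5, c2 = 7;
  int iteration = 0;
  std::array<float, 10> weights{};  // starts as vector<float>(10, 0.0f)
  while (iteration < c1) {
    ++iteration;
    for (float& w : weights) w += 1.0f;
  }
  const std::array<float, 10> w0 = weights;  // result of the first loop
  while (iteration < c2) {
    ++iteration;
    for (float& w : weights) w += 1.0f;
  }
  for (int i = 0; i < 10; ++i) assert(w0[i] + weights[i] == c1 + c2);
  return 0;
}
```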
+TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + auto while2 = builder.While(condition2, body, init); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// WhileTest that uses DynamicUpdateSlice instruction in body computation. +// Loop state tuple element 1 has as its single user operand(0) of +// DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU. +XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. 
+ Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + // TupleElement 0 + auto iteration = builder.GetTupleElement(prev, 0); + auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + // TupleElement 1 + auto input = builder.GetTupleElement(prev, 1); + // Update. + auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32); + // Starts = iteration * 2; + auto starts = builder.Reshape( + builder.Mul(iteration, builder.ConstantR0(2)), {1}); + // UpdateSlice. + auto out1 = builder.DynamicUpdateSlice(input, update, starts); + + auto result = builder.Tuple({out0, out1}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( + {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + // Tests a while node when the result type T is a vector of S32. // // int32 result = (0, 0, 0, 0, 0, 0); @@ -254,7 +541,8 @@ TEST_F(WhileTest, WhileWithTupleResult) { // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]); // } // -// This test misuses a vector to represent a pair: +// This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a +// pair: // ((iteration, (random vector))). // // Note: this test currently only tests generating random values within a loop. @@ -268,7 +556,8 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { auto build_condition = [this, v6s32](int count) { ComputationBuilder builder(client_, TestName()); auto prev = builder.Reshape( - builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {}); + builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, + {}); builder.Gt(builder.ConstantR0(count), prev); return builder.Build().ConsumeValueOrDie(); }; @@ -308,6 +597,74 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } +// Tests nested while loops. +// +// int32 result = 0; +// while (result < 30) { +// int i = 0; +// while (i < 7) { +// result = result + 2; +// i = i + 1; +// } +// } +XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { + auto outer_result_shape = ShapeUtil::MakeShape(S32, {}); + auto inner_result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); + + Computation inner_condition; + { + ComputationBuilder builder(client_, "inner_condition"); + auto params = builder.Parameter(0, inner_result_shape, "prev"); + auto i = builder.GetTupleElement(params, 0); + builder.Lt(i, builder.ConstantR0(7)); + inner_condition = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the outer loop condition: + // repeat while result < 30. 
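The expected literal in `WhileWithDynamicUpdateSlice` above follows directly from the body: trip k (zero-based) writes the incremented counter k + 1 into the two-element window at offset 2k. A standalone host model:

```c++
#include <cassert>

int main() {
  // Host model of WhileWithDynamicUpdateSlice: each trip writes
  // broadcast(out0, {2}) at starts = iteration * 2.
  float data[10] = {0};
  for (int iteration = 0; iteration < 5; ++iteration) {
    const float out0 = static_cast<float>(iteration + 1);
    const int starts = iteration * 2;
    data[starts] = out0;
    data[starts + 1] = out0;
  }
  const float expected[10] = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
  for (int i = 0; i < 10; ++i) assert(data[i] == expected[i]);
  return 0;
}
```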
+ Computation outer_condition; + { + ComputationBuilder builder(client_, "outer_condition"); + auto prev = builder.Parameter(0, outer_result_shape, "prev"); + builder.Lt(prev, builder.ConstantR0(30)); + outer_condition = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to + // `result`. + Computation inner_body; + { + ComputationBuilder builder(client_, "inner_body"); + auto params = builder.Parameter(0, inner_result_shape, "prev"); + auto i = builder.GetTupleElement(params, 0); + auto result = builder.GetTupleElement(params, 1); + i = builder.Add(builder.ConstantR0(1), i); + result = builder.Add(builder.ConstantR0(2), result); + auto output = builder.Tuple({i, result}); + inner_body = builder.Build().ConsumeValueOrDie(); + } + + // Creates a computation for the outer loop: run the inner loop with i = 0. + Computation outer_body; + { + ComputationBuilder builder(client_, "outer_body"); + auto prev = builder.Parameter(0, outer_result_shape, "prev"); + auto init = builder.Tuple({builder.ConstantR0(0), prev}); + auto result = builder.While(inner_condition, inner_body, init); + auto output = builder.GetTupleElement(result, 1); + outer_body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.ConstantR0(0); + auto result = builder.While(outer_condition, outer_body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 42, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); @@ -354,19 +711,23 @@ void BM_WhileLoop(int num_iters) { builder.While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); + std::unique_ptr executable = + client->Compile(computation, {}, ExecutableBuildOptions()) + .ConsumeValueOrDie(); + // Run some warm-up executions. - LocalExecuteOptions options; + ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = client->ExecuteLocally(computation, {}, options); + auto result = executable->Run({}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = client->ExecuteLocally(computation, {}, options); + auto result = executable->Run({}, options); ASSERT_TRUE(result.ok()); } } @@ -381,6 +742,7 @@ BENCHMARK(BM_WhileLoop); int main(int argc, char** argv) { std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 3cfbb2c7fbf..e45e5291c9b 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,6 +18,7 @@ limitations under the License. 
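The constant 42 expected by `NestedWhileWithScalarResult` can be checked by hand: each outer trip drains the inner loop for 7 iterations of +2, and the outer condition is only re-evaluated between trips, so the result steps 0, 14, 28, 42. As a standalone check:

```c++
#include <cassert>

int main() {
  // Host model of NestedWhileWithScalarResult: result passes 30 mid-trip and
  // stops at 42, not at the first value >= 30.
  int result = 0;
  while (result < 30) {
    int i = 0;
    while (i < 7) {
      result = result + 2;
      i = i + 1;
    }
  }
  assert(result == 42);
  return 0;
}
```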
#include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 94d0f2646b1..a167d80f73b 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 545bd22da91..7375493f430 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 9dce4d13bb0..177ae4ea036 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -19,12 +19,12 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 46eab7f02bb..535e5b605b4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -153,6 +153,7 @@ cc_binary( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -176,6 +177,24 @@ cc_binary( ], ) +cc_binary( + name = "dumped_computation_to_tf_graphdef", + srcs = ["dumped_computation_to_tf_graphdef.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4c242abc9b7..8d7f7fd1237 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -81,6 +81,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { client->GetComputationShape(computation).ConsumeValueOrDie(); std::vector layouts; + layouts.reserve(program_shape->parameters_size()); for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 8b96e134897..2a3a8803283 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -50,23 +51,37 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } Computation computation = computation_status.ConsumeValueOrDie(); - std::unique_ptr program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); + if (compile) { + std::unique_ptr program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); - std::vector layouts; - for (int i = 0; i < program_shape->parameters_size(); ++i) { - layouts.push_back(&program_shape->parameters(i)); + std::vector layouts; + layouts.reserve(program_shape->parameters_size()); + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const HloModule& module = executable.ValueOrDie()->module(); + + fprintf(stdout, "HLO compiled for %s backend:\n%s\n", + local_service->backend().platform()->Name().c_str(), + module.ToString().c_str()); + } else { + const ComputationTracker& tracker = local_service->computation_tracker(); + UserComputation* user_computation = + tracker.Resolve(computation.handle()).ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + std::unique_ptr module = + tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + .ConsumeValueOrDie(); + + fprintf(stdout, "%s\n", module->ToString().c_str()); } - StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); - - const HloModule& module = executable.ValueOrDie()->module(); - - fprintf(stdout, "HLO for %s backend:\n%s\n", - local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); } } @@ -74,10 +89,21 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); + bool compile = false; + std::vector flag_list = { + {"compile", &compile, + "If true, compile the computation using the default client before " + "dumping the HLO. Otherwise dump the raw (uncompiled) HLO."}, + }; + const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); + xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc new file mode 100644 index 00000000000..850267d3195 --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* +// +// Dumps a tensorflow GraphDef in text format for a snapshot computation. The +// dumped graph is an HLO computation with HLO instructions as nodes and can be +// visualized on Tensorboard. Upload the dumped files on Tensorboard. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::Env; + +namespace xla { +namespace tools { + +void RealMain(tensorflow::gtl::ArraySlice args) { + Client* client = ClientLibrary::LocalClientOrDie(); + for (char* arg : args) { + SessionModule module; + TF_CHECK_OK( + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); + Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + ComputationStats stats = + client->GetComputationStats(computation).ConsumeValueOrDie(); + fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); + flags->xla_generate_hlo_graph = ".*"; + + xla::legacy_flags::HloGraphDumperFlags* dumper_flags = + xla::legacy_flags::GetHloGraphDumperFlags(); + dumper_flags->xla_hlo_dump_as_graphdef = true; + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ffb2d5aefba..3a75bf64954 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -66,7 +66,8 @@ StatusOr> ReplayComputation( if (use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available - for (const Literal& literal : module.arguments()) { + for (const auto& proto : module.arguments()) { + Literal literal(proto); TF_ASSIGN_OR_RETURN(std::unique_ptr data, client->TransferToServer(literal)); 
arguments.push_back(std::move(data)); @@ -74,6 +75,7 @@ StatusOr> ReplayComputation( } std::vector execute_arguments; + execute_arguments.reserve(arguments.size()); for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } @@ -100,7 +102,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(module.result()).c_str()); + LiteralUtil::ToString(Literal(module.result())).c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index cf363913b15..b6538f5de07 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -37,9 +37,10 @@ int main(int argc, char **argv) { << " "; } - xla::Literal literal; + xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], - &literal)); - LOG(INFO) << "literal: " << literal.ShortDebugString(); + &literal_proto)); + xla::Literal literal(literal_proto); + LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 8258031a2c5..ea8b4b7b989 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ #define TENSORFLOW_COMPILER_XLA_TYPES_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" +#include + namespace xla { using ::tensorflow::string; @@ -32,6 +35,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using ::Eigen::half; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 3ee5dfc9496..d467178cb52 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -33,7 +33,7 @@ namespace { // Adds a backtrace to the provided status iff the xla_status_add_backtrace flag // is set. This is useful for quickly tracing status errors observed coming out // of the service. 
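The replay_computation.cc and show_literal.cc fixes above adapt to `Literal` becoming a class distinct from the wire-format `LiteralProto`. A condensed fragment of the new read path (not standalone; `path` stands in for the tool's `argv[1]`, and the `Literal(const LiteralProto&)` constructor is assumed from this change rather than from any documented API):

```c++
// Read the serialized proto from disk, then wrap it before any LiteralUtil
// call expects a first-class Literal.
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), path,
                                        &literal_proto));
xla::Literal literal(literal_proto);  // proto -> first-class Literal
fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
```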
-Status MaybeAddBacktrace(Status prior) { +Status MaybeAddBacktrace(const Status& prior) { DCHECK(!prior.ok()); if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) { return Status{prior.code(), @@ -153,16 +153,26 @@ string Reindent(tensorflow::StringPiece original, }); } +bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { + if (rank != permutation.size()) { + return false; + } + std::vector output(permutation.size(), -1); + for (auto index : permutation) { + CHECK_GE(index, 0); + CHECK_LT(index, rank); + output[index] = 0; + } + return std::find(output.begin(), output.end(), -1) == output.end(); +} + std::vector InversePermutation( tensorflow::gtl::ArraySlice input_permutation) { + DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { output_permutation[input_permutation[i]] = i; } - DCHECK_EQ( - 0, std::count(output_permutation.begin(), output_permutation.end(), -1)); - DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(), - output_permutation.begin())); return output_permutation; } @@ -176,6 +186,15 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { + for (int64 i = 0; i < p.size(); ++i) { + if (p[i] != i) { + return false; + } + } + return true; +} + PaddingConfig MakeNoPaddingConfig(int64 rank) { PaddingConfig padding_config; for (int64 dnum = 0; dnum < rank; ++dnum) { @@ -187,6 +206,15 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { return padding_config; } +bool HasInteriorPadding(const PaddingConfig& config) { + for (const auto& dim : config.dimensions()) { + if (dim.interior_padding() != 0) { + return true; + } + } + return false; +} + string HumanReadableNumFlops(double flops, double nanoseconds) { if (nanoseconds == 0) { return "NaN FLOP/s"; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 00f8d946f89..42d5c1d1550 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -38,6 +39,13 @@ limitations under the License. namespace xla { +// Ranks greater than 8 are very rare, so use InlinedVector to store +// the bounds and indices. And for the rare cases of ranks greater than 8, +// the InlinedVector will just behave like an std::vector<> and allocate the +// memory to store its values. +static constexpr int kInlineRank = 8; +using DimensionVector = tensorflow::gtl::InlinedVector; + // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. @@ -120,6 +128,14 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2) { std::equal(std::begin(c1), std::end(c1), std::begin(c2))); } +template +bool ContainersEqual(const Container1T& c1, + std::initializer_list il) { + tensorflow::gtl::ArraySlice c2{il}; + return ContainersEqual(c1, c2); +} + // Compares two containers for equality. 
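The new util.h helpers are small enough to model exactly. Below is a standalone sketch with local stand-ins for `IsPermutation`, the `InversePermutation` contract (output[input[i]] = i), and the `StridedCopy` template introduced a few hunk lines further down; plain `int` replaces xla's `int64` for brevity:

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Local stand-in matching the documented contract: true iff `permutation`
// is a permutation of the [0, rank) integer range.
bool IsPermutation(const std::vector<int>& permutation, int rank) {
  if (rank != static_cast<int>(permutation.size())) return false;
  std::vector<bool> seen(permutation.size(), false);
  for (int index : permutation) {
    if (index < 0 || index >= rank || seen[index]) return false;
    seen[index] = true;
  }
  return true;
}

// Local stand-in for StridedCopy: copies `count` values with independent
// source and destination strides, casting element types along the way.
template <typename D, typename S>
void StridedCopy(std::vector<D>& dest, int dest_base, int dest_stride,
                 const std::vector<S>& src, int src_base, int src_stride,
                 int count) {
  for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride)
    dest[dest_base] = static_cast<D>(src[src_base]);
}

int main() {
  // {2, 0, 1} permutes [0, 3); its inverse is {1, 2, 0}, because
  // InversePermutation satisfies output[input[i]] = i.
  const std::vector<int> perm = {2, 0, 1};
  assert(IsPermutation(perm, 3));
  std::vector<int> inverse(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inverse[perm[i]] = static_cast<int>(i);
  assert((inverse == std::vector<int>{1, 2, 0}));

  // Gather every second int from src into dest, casting to float.
  const std::vector<int> src = {0, 10, 20, 30, 40, 50};
  std::vector<float> dest(3, 0.0f);
  StridedCopy(dest, /*dest_base=*/0, /*dest_stride=*/1, src,
              /*src_base=*/0, /*src_stride=*/2, /*count=*/3);
  assert(dest[0] == 0.0f && dest[1] == 20.0f && dest[2] == 40.0f);
  return 0;
}
```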
Returns true iff the two containers // have the same size and all their elements compare equal using the predicate // p. Like std::equal, but forces size equality. @@ -130,6 +146,18 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, std::equal(std::begin(c1), std::end(c1), std::begin(c2), p)); } +// Performs a copy of count values from src to dest, using different strides for +// source and destination. The source starting index is src_base, while the +// destination one is dest_base. +template <typename D, typename S> +void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base, + int64 dest_stride, tensorflow::gtl::ArraySlice<S> src, + int64 src_base, int64 src_stride, int64 count) { + for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { + dest[dest_base] = static_cast<D>(src[src_base]); + } +} + // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. @@ -156,6 +184,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); string Reindent(tensorflow::StringPiece original, tensorflow::StringPiece indentation); +// Checks whether permutation is a permutation of the [0, rank) integer range. +bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank); + // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. // @@ -166,12 +197,11 @@ template diff --git a/tensorflow/tensorboard/components/tf_backend/BUILD b/tensorflow/tensorboard/components/tf_backend/BUILD new file mode 100644 index 00000000000..50fc267dc4d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_backend/BUILD @@ -0,0 +1,45 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_backend", + srcs = [ + "backend.ts", + "behavior.ts", + "requestManager.ts", + "router.ts", + "runsStore.ts", + "tf-backend.html", + "urlPathHelpers.ts", + ], + path = "/tf-backend", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:plottable", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/vz_sorting", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_backend"], + destdir = "tf-backend", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_google:lib", + "//tensorflow/tensorboard/components/vz_sorting:legacy", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts index 28a5b2d0e14..023414b6b75 100644 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ b/tensorflow/tensorboard/components/tf_backend/backend.ts @@ -13,427 +13,596 @@ See the License for the specific language governing permissions and limitations under the License.
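StridedCopy, added to util.h above, copies count elements while advancing the source and destination cursors by independent strides. A sketch of the same loop in TypeScript, with a small gather example:

```typescript
// TypeScript rendering of the StridedCopy loop added to util.h: copy
// `count` elements, advancing each side by its own stride.
function stridedCopy(
    dest: number[], destBase: number, destStride: number,
    src: number[], srcBase: number, srcStride: number, count: number): void {
  for (; count > 0; --count, destBase += destStride, srcBase += srcStride) {
    dest[destBase] = src[srcBase];
  }
}

// Gather every second element of src into consecutive dest slots.
const src = [10, 11, 12, 13, 14, 15];
const dest = [0, 0, 0];
stridedCopy(dest, 0, 1, src, 0, 2, 3);
console.log(dest);  // [10, 12, 14]
```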
==============================================================================*/ -module TF.Backend { - export interface RunEnumeration { - histograms: string[]; - compressedHistogramTuples: string[]; - scalars: string[]; - images: string[]; - audio: string[]; - graph: boolean; - run_metadata: string[]; - } +import {compareTagNames} from '../vz-sorting/sorting'; +import {RequestManager} from './requestManager'; +import {getRouter} from './router'; +import {demoify, queryEncoder} from './urlPathHelpers'; - export interface LogdirResponse { logdir: string; } +export interface RunEnumeration { + histograms: string[]; + compressedHistogramTuples: string[]; + scalars: string[]; + images: string[]; + audio: string[]; + graph: boolean; + run_metadata: string[]; +} - export interface RunsResponse { [runName: string]: RunEnumeration; } +export interface LogdirResponse { logdir: string; } - export type RunToTag = {[run: string]: string[];}; +export interface RunsResponse { [runName: string]: RunEnumeration; } - export interface Datum { - wall_time: Date; - step: number; - } +export type RunToTag = { + [run: string]: string[]; +}; - export type ScalarDatum = Datum & Scalar; - export interface Scalar { scalar: number; } +export interface Datum { + wall_time: Date; + step: number; +} - export type HistogramDatum = Datum & Histogram; - export interface Histogram { - min: number; - max: number; - nItems?: number; - sum?: number; - sumSquares?: number; - bucketRightEdges: number[]; - bucketCounts: number[]; - } +export type ScalarDatum = Datum & Scalar; +export interface Scalar { scalar: number; } - export interface HistogramBin { - x: number; - dx: number; - y: number; - } - export type HistogramSeriesDatum = HistogramSeries & Datum; - export interface HistogramSeries { bins: HistogramBin[]; } +export interface Text { text: string; } +export type TextDatum = Datum & Text; - export type ImageDatum = Datum & Image; - export interface Image { - width: number; - height: number; - url: string; - } +export type HistogramDatum = Datum & Histogram; +export interface Histogram { + min: number; + max: number; + nItems?: number; + sum?: number; + sumSquares?: number; + bucketRightEdges: number[]; + bucketCounts: number[]; +} - export type AudioDatum = Datum & Audio; - export interface Audio { - content_type: string; - url: string; - } +export interface HistogramBin { + x: number; + dx: number; + y: number; +} +export type HistogramSeriesDatum = HistogramSeries & Datum; +export interface HistogramSeries { bins: HistogramBin[]; } - // A health pill encapsulates an overview of tensor element values. The value - // field is a list of 12 numbers that shed light on the status of the tensor. - export interface HealthPill { - node_name: string; - output_slot: number; - value: number[]; - }; - export type HealthPillDatum = Datum & HealthPill; - // A health pill response is a mapping from node name to a list of health pill - // data entries. - export interface HealthPillsResponse { [key: string]: HealthPillDatum[]; }; +export type ImageDatum = Datum & Image; +export interface Image { + width: number; + height: number; + url: string; +} + +export type AudioDatum = Datum & Audio; +export interface Audio { + content_type: string; + url: string; +} + +// A health pill encapsulates an overview of tensor element values. The value +// field is a list of 12 numbers that shed light on the status of the tensor. 
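The exported datum types above lean on TypeScript intersection types: a ScalarDatum is just a Datum with a scalar field merged in. A self-contained sketch of how such a value is assembled:

```typescript
// ScalarDatum = Datum & Scalar: a value must satisfy both halves.
interface Datum {
  wall_time: Date;
  step: number;
}
interface Scalar {
  scalar: number;
}
type ScalarDatum = Datum & Scalar;

const point: ScalarDatum = {
  wall_time: new Date(40000),
  step: 7,
  scalar: 0.25,
};
```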
+export interface HealthPill { + device_name: string; + node_name: string; + output_slot: number; + dtype: string; + shape: number[]; + value: number[]; +} + +// When updating this type, keep it consistent with the HealthPill interface +// in tf_graph_common/lib/scene/scene.ts. +export type HealthPillDatum = Datum & HealthPill; +// A health pill response is a mapping from node name to a list of health pill +// data entries. +export interface HealthPillsResponse { [key: string]: HealthPillDatum[]; } + +// An object that encapsulates an alert issued by the debugger. This alert is +// sent by debugging libraries after bad values (NaN, +/- Inf) are encountered. +export interface DebuggerNumericsAlertReport { + device_name: string; + tensor_name: string; + first_timestamp: number; + nan_event_count: number; + neg_inf_event_count: number; + pos_inf_event_count: number; +} +// A DebuggerNumericsAlertReportResponse contains alerts issued by the debugger +// in ascending order of timestamp. This helps the user identify for instance +// when bad values first appeared in the model. +export type DebuggerNumericsAlertReportResponse = DebuggerNumericsAlertReport[]; + +export const TYPES = [ + 'scalar', 'histogram', 'compressedHistogram', 'graph', 'image', 'audio', + 'runMetadata', 'text' +]; +/** + * The Backend class provides a convenient and typed interface to the backend. + * + * It provides methods corresponding to the different data sources on the + * TensorBoard backend. These methods return a promise containing the data + * from the backend. This class does some post-processing on the data; for + * example, converting data elements tuples into js objects so that they can + * be accessed in a more convenient and clearly-documented fashion. + */ +export class Backend { + public requestManager: RequestManager; - export var TYPES = [ - 'scalar', 'histogram', 'compressedHistogram', 'graph', 'image', 'audio', - 'runMetadata' - ]; /** - * The Backend class provides a convenient and typed interface to the backend. - * - * It provides methods corresponding to the different data sources on the - * TensorBoard backend. These methods return a promise containing the data - * from the backend. This class does some post-processing on the data; for - * example, converting data elements tuples into js objects so that they can - * be accessed in a more convenient and clearly-documented fashion. + * Construct a Backend instance. + * @param requestManager The RequestManager, overwritable so you may + * manually clear request queue, etc. Defaults to a new RequestManager. */ - export class Backend { - public router: Router; - public requestManager: RequestManager; + constructor(requestManager?: RequestManager) { + this.requestManager = requestManager || new RequestManager(); + } - /** - * Construct a Backend instance. - * @param router the Router with info on what urls to get data from - * @param requestManager The RequestManager, overwritable so you may - * manually clear request queue, etc. Defaults to a new RequestManager. - */ - constructor(router: Router, requestManager?: RequestManager) { - this.router = router; - this.requestManager = requestManager || new RequestManager(); + /** + * Returns a promise for requesting the logdir string. + */ + public logdir(): Promise { + return this.requestManager.request(getRouter().logdir()); + } + + /** + * Returns a listing of all the available data in the TensorBoard backend. 
+ */ + public runs(): Promise { + return this.requestManager.request(getRouter().runs()); + } + + /** + * Return a promise showing the Run-to-Tag mapping for scalar data. + */ + public scalarTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('scalars', '/tags')); + } + + /** + * Return a promise showing the Run-to-Tag mapping for histogram data. + */ + public histogramTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('histograms', '/tags')); + } + + /** + * Return a promise showing the Run-to-Tag mapping for image data. + */ + public imageTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('images', '/tags')); + } + + /** + * Return a promise showing the Run-to-Tag mapping for audio data. + */ + public audioTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('audio', '/tags')); + } + + /** + * Return a promise showing the Run-to-Tag mapping for compressedHistogram + * data. + */ + public compressedHistogramTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('distributions', '/tags')); + } + + /** + * Returns a promise showing the Run-to-Tag mapping for profile data. + */ + public profileTags(): Promise { + let url = getRouter().pluginRoute('profile', '/tags'); + if (getRouter().isDemoMode()) { + url += '.json'; } + return this.requestManager.request(url); + } - /** - * Returns a promise for requesting the logdir string. - */ - public logdir(): Promise { - return this.requestManager.request(this.router.logdir()); + /** + * Return a promise showing list of runs that contain graphs. + */ + public graphRuns(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('graphs', '/runs')); + } + + /** + * Return a promise showing the Run-to-Tag mapping for run_metadata objects. + */ + public runMetadataTags(): Promise { + return this.requestManager.request( + getRouter().pluginRoute('graphs', '/run_metadata_tags')); + } + + + /** + * Returns a promise showing the Run-to-Tag mapping for text data. + */ + public textRuns(): Promise { + return this.requestManager.request(getRouter().textRuns()); + } + + + /** + * Returns a promise containing TextDatums for given run and tag. + */ + public text(tag: string, run: string): Promise { + const url = getRouter().text(tag, run); + // tslint:disable-next-line:no-any it's convenient and harmless here + return this.requestManager.request(url).then(map((x: any) => { + x.wall_time = timeToDate(x.wall_time); + return x; + })); + } + + /** + * Return a URL to fetch a graph (cf. method 'graph'). + */ + public graphUrl(run: string, limitAttrSize?: number, largeAttrsKey?: string): + string { + const demoMode = getRouter().isDemoMode(); + const base = getRouter().pluginRoute('graphs', '/graph'); + const optional = (p) => (p != null && !demoMode || undefined) && p; + const parameters = { + 'run': run, + 'limit_attr_size': optional(limitAttrSize), + 'large_attrs_key': optional(largeAttrsKey), + }; + const extension = demoMode ? '.pbtxt' : ''; + return base + queryEncoder(parameters) + extension; + } + + public graph(run: string, limitAttrSize?: number, largeAttrsKey?: string): + Promise { + const url = this.graphUrl(run, limitAttrSize, largeAttrsKey); + return this.requestManager.request(url); + } + + /** + * Return a promise containing ScalarDatums for given run and tag. 
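Each of the *Tags methods above resolves to a RunToTag mapping served by the corresponding plugin route. A sketch of the shape consumers receive (run and tag names invented):

```typescript
// Shape of the response the tag routes resolve to: a RunToTag mapping.
const runToTag: {[run: string]: string[]} = {
  'train': ['loss', 'accuracy'],
  'eval': ['loss'],
};

// Typical consumption: list the runs that carry a particular tag.
const runsWithLoss =
    Object.keys(runToTag).filter((run) => runToTag[run].indexOf('loss') >= 0);
console.log(runsWithLoss);  // ['train', 'eval']
```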
+ */ + public scalar(tag: string, run: string): Promise> { + let p: Promise[]>; + const url = getRouter().pluginRunTagRoute('scalars', '/scalars')(tag, run); + p = this.requestManager.request(url); + return p.then(map(detupler(createScalar))); + } + + /** + * Returns a promise for requesting the health pills for a list of nodes. This + * route is used by the debugger plugin. + */ + public healthPills(nodeNames: string[], step?: number): + Promise { + const postData = { + 'node_names': JSON.stringify(nodeNames), + + // Events files with debugger data fall under this special run. + 'run': '__debugger_data__', + }; + if (step !== undefined) { + // The user requested health pills for a specific step. This request + // might be slow since the backend reads events sequentially from disk. + postData['step'] = step; } + return this.requestManager.request(getRouter().healthPills(), postData); + } - /** - * Returns a listing of all the available data in the TensorBoard backend. - */ - public runs(): Promise { - return this.requestManager.request(this.router.runs()); - } + /** + * Returns a promise for alerts for bad values (detected by the debugger). + * This route is used by the debugger plugin. + */ + public debuggerNumericsAlerts(): + Promise { + return this.requestManager.request( + getRouter().pluginRoute('debugger', '/numerics_alert_report')); + } - /** - * Return a promise showing the Run-to-Tag mapping for scalar data. - */ - public scalarRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'scalars')); - } + /** + * Return a promise containing HistogramDatums for given run and tag. + */ + public histogram(tag: string, run: string): + Promise> { + let p: Promise[]>; + const url = + getRouter().pluginRunTagRoute('histograms', '/histograms')(tag, run); + p = this.requestManager.request(url); + return p.then(map(detupler(createHistogram))).then(function(histos) { + // Get the minimum and maximum values across all histograms so that the + // visualization is aligned for all timesteps. + const min = d3.min(histos, d => d.min); + const max = d3.max(histos, d => d.max); - /** - * Return a promise showing the Run-to-Tag mapping for histogram data. - */ - public histogramRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'histograms')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for image data. - */ - public imageRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'images')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for audio data. - */ - public audioRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'audio')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for compressedHistogram - * data. - */ - public compressedHistogramRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'compressedHistograms')); - } - - /** - * Return a promise showing list of runs that contain graphs. - */ - public graphRuns(): Promise { - return this.runs().then( - (x) => { return _.keys(x).filter((k) => x[k].graph); }); - } - - /** - * Return a promise showing the Run-to-Tag mapping for run_metadata objects. - */ - public runMetadataRuns(): Promise { - return this.runs().then((x) => _.mapValues(x, 'run_metadata')); - } - - /** - * Return a promise of a graph string from the backend. 
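Note how the histogram method pins every timestep to one shared [min, max] range before re-binning, so the x-axes line up across steps. The same alignment step, sketched without d3:

```typescript
// Compute one global [min, max] across all timesteps so that every
// histogram is re-binned against the same axis (the role d3.min/d3.max
// play in the histogram method above).
interface Range { min: number; max: number; }

function globalRange(histos: Range[]): [number, number] {
  const min = Math.min(...histos.map((h) => h.min));
  const max = Math.max(...histos.map((h) => h.max));
  return [min, max];
}

console.log(globalRange([{min: -1, max: 2}, {min: 0, max: 5}]));  // [-1, 5]
```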
- */ - public graph( - tag: string, limit_attr_size?: number, - large_attrs_key?: string): Promise { - let url = this.router.graph(tag, limit_attr_size, large_attrs_key); - return this.requestManager.request(url); - } - - /** - * Return a promise containing ScalarDatums for given run and tag. - */ - public scalar(tag: string, run: string): Promise> { - let p: Promise[]>; - let url = this.router.scalars(tag, run); - p = this.requestManager.request(url); - return p.then(map(detupler(createScalar))); - } - - /** - * Returns a promise for requesting the health pills for a list of nodes. - */ - public healthPills(nodeNames: string[]): Promise { - let postData = {'node_names': JSON.stringify(nodeNames)}; - return this.requestManager.request(this.router.healthPills(), postData); - } - - /** - * Return a promise containing HistogramDatums for given run and tag. - */ - public histogram(tag: string, run: string): - Promise> { - let p: Promise[]>; - let url = this.router.histograms(tag, run); - p = this.requestManager.request(url); - return p.then(map(detupler(createHistogram))).then(function(histos) { - // Get the minimum and maximum values across all histograms so that the - // visualization is aligned for all timesteps. - let min = d3.min(histos, d => d.min); - let max = d3.max(histos, d => d.max); - - return histos.map(function(histo, i) { - return { - wall_time: histo.wall_time, - step: histo.step, - bins: convertBins(histo, min, max) - }; - }); + return histos.map(function(histo, i) { + return { + wall_time: histo.wall_time, + step: histo.step, + bins: convertBins(histo, min, max) + }; }); - } - - /** - * Return a promise containing ImageDatums for given run and tag. - */ - public image(tag: string, run: string): Promise> { - let url = this.router.images(tag, run); - let p: Promise; - p = this.requestManager.request(url); - return p.then(map(this.createImage.bind(this))); - } - - /** - * Return a promise containing AudioDatums for given run and tag. - */ - public audio(tag: string, run: string): Promise> { - let url = this.router.audio(tag, run); - let p: Promise; - p = this.requestManager.request(url); - return p.then(map(this.createAudio.bind(this))); - } - - /** - * Returns a promise to load the string RunMetadata for given run/tag. - */ - public runMetadata(tag: string, run: string): Promise { - let url = this.router.runMetadata(tag, run); - return this.requestManager.request(url); - } - - /** - * Get compressedHistogram data. - * Unlike other methods, don't bother reprocessing this data into a nicer - * format. This is because we will deprecate this route. 
- */ - private compressedHistogram(tag: string, run: string): - Promise> { - let url = this.router.compressedHistograms(tag, run); - let p: Promise[]>; - p = this.requestManager.request(url); - return p.then(map(detupler((x) => x))); - } - - private createImage(x: ImageMetadata): Image&Datum { - return { - width: x.width, - height: x.height, - wall_time: timeToDate(x.wall_time), - step: x.step, - url: this.router.individualImage(x.query, x.wall_time), - }; - } - - private createAudio(x: AudioMetadata): Audio&Datum { - return { - content_type: x.content_type, - wall_time: timeToDate(x.wall_time), - step: x.step, - url: this.router.individualAudio(x.query), - }; - } - } - - /** Given a RunToTag, return sorted array of all runs */ - export function getRuns(r: RunToTag): string[] { - return _.keys(r).sort(VZ.Sorting.compareTagNames); - } - - /** Given a RunToTag, return array of all tags (sorted + dedup'd) */ - export function getTags(r: RunToTag): string[] { - return _.union.apply(null, _.values(r)).sort(VZ.Sorting.compareTagNames); - } - - /** - * Given a RunToTag and an array of runs, return every tag that appears for - * at least one run. - * Sorted, deduplicated. - */ - export function filterTags(r: RunToTag, runs: string[]): string[] { - var result = []; - runs.forEach((x) => result = result.concat(r[x])); - return _.uniq(result).sort(VZ.Sorting.compareTagNames); - } - - function timeToDate(x: number): Date { return new Date(x * 1000); }; - - /** Just a curryable map to make things cute and tidy. */ - function map(f: (x: T) => U): (arr: T[]) => U[] { - return function(arr: T[]): U[] { return arr.map(f); }; - }; - - /** - * This is a higher order function that takes a function that transforms a - * T into a G, and returns a function that takes TupleDatas and converts - * them into the intersection of a G and a Datum. - */ - function detupler(xform: (x: T) => G): (t: TupleData) => Datum & G { - return function(x: TupleData): Datum & G { - // Create a G, assert it has type - let obj = xform(x[2]); - // ... patch in the properties of datum - obj.wall_time = timeToDate(x[0]); - obj.step = x[1]; - return obj; - }; - }; - - function createScalar(x: number): Scalar { return {scalar: x}; }; - - function createHistogram(x: HistogramTuple): Histogram { - return { - min: x[0], - max: x[1], - nItems: x[2], - sum: x[3], - sumSquares: x[4], - bucketRightEdges: x[5], - bucketCounts: x[6], - }; - }; - - /** - * Takes histogram data as stored by tensorboard backend and converts it to - * the standard d3 histogram data format to make it more compatible and easier - * to visualize. When visualizing histograms, having the left edge and width - * makes things quite a bit easier. The bins are also converted to have an - * uniform width, what makes the visualization easier to understand. - * - * @param histogram A histogram from tensorboard backend. - * @param min The leftmost edge. The binning will start on it. - * @param max The rightmost edge. The binning will end on it. - * @param numBins The number of bins of the converted data. The default of 30 - * is a sensible default, using more starts to get artifacts because the event - * data is stored in buckets, and you start being able to see the aliased - * borders between each bucket. - * @return A histogram bin. Each bin has an x (left edge), a dx (width), - * and a y (count). - * - * If given rightedges are inclusive, then these left edges (x) are exclusive. 
- */ - export function convertBins( - histogram: Histogram, min: number, max: number, numBins = 30) { - if (histogram.bucketRightEdges.length !== histogram.bucketCounts.length) { - throw(new Error('Edges and counts are of different lengths.')); - } - - if (max === min) { - // Create bins even if all the data has a single value. - max = min * 1.1 + 1; - min = min / 1.1 - 1; - } - let binWidth = (max - min) / numBins; - let bucketLeft = min; // Use the min as the starting point for the bins. - let bucketPos = 0; - return d3.range(min, max, binWidth).map(function(binLeft) { - let binRight = binLeft + binWidth; - - // Take the count of each existing bucket, multiply it by the proportion - // of overlap with the new bin, then sum and store as the count for the - // new bin. If no overlap, will add to zero, if 100% overlap, will include - // the full count into new bin. - let binY = 0; - while (bucketPos < histogram.bucketRightEdges.length) { - // Clip the right edge because right-most edge can be infinite-sized. - let bucketRight = Math.min(max, histogram.bucketRightEdges[bucketPos]); - - let intersect = - Math.min(bucketRight, binRight) - Math.max(bucketLeft, binLeft); - let count = (intersect / (bucketRight - bucketLeft)) * - histogram.bucketCounts[bucketPos]; - - binY += intersect > 0 ? count : 0; - - // If bucketRight is bigger than binRight, than this bin is finished and - // there is data for the next bin, so don't increment bucketPos. - if (bucketRight > binRight) { - break; - } - bucketLeft = Math.max(min, bucketRight); - bucketPos++; - }; - - return {x: binLeft, dx: binWidth, y: binY}; }); } /** - * The following interfaces (TupleData, HistogramTuple, - * CompressedHistogramTuple, ImageMetadata, and AudioMetadata) describe how - * the data is sent over from the backend. + * Return a promise containing ImageDatums for given run and tag. */ - type TupleData = [number, number, T]; // wall_time, step - - // Min, Max, nItems, Sum, Sum_Squares, right edges of buckets, nItems in - // buckets - type HistogramTuple = - [number, number, number, number, number, number[], number[]]; - type CompressedHistogramTuple = [number, number][]; // percentile, value - interface ImageMetadata { - width: number; - height: number; - wall_time: number; - step: number; - query: string; + public image(tag: string, run: string): Promise> { + const url = (getRouter().pluginRunTagRoute('images', '/images')(tag, run)); + let p: Promise; + p = this.requestManager.request(url); + return p.then(map(this.createImage.bind(this))); } - interface AudioMetadata { - content_type: string; - wall_time: number; - step: number; - query: string; + + /** + * Return a promise containing AudioDatums for given run and tag. + */ + public audio(tag: string, run: string): Promise> { + const url = (getRouter().pluginRunTagRoute('audio', '/audio')(tag, run)); + let p: Promise; + p = this.requestManager.request(url); + return p.then(map(this.createAudio.bind(this))); + } + + /** + * Returns a promise containing profile data for given run and tag. + */ + public profile(tag: string, run: string): Promise { + let url = (getRouter().pluginRunTagRoute('profile', '/data')(tag, run)); + if (getRouter().isDemoMode()) { + url += '.json'; + } + return this.requestManager.request(url); + } + + /** + * Returns the url for the RunMetadata for the given run/tag. 
+ */ + public runMetadataUrl(tag: string, run: string): string { + return getRouter().pluginRunTagRoute('graphs', '/run_metadata')(tag, run); + } + + /** + * Returns a promise to load the string RunMetadata for given run/tag. + */ + public runMetadata(tag: string, run: string): Promise { + const url = this.runMetadataUrl(tag, run); + return this.requestManager.request(url); + } + + /** + * Get compressedHistogram data. + * Unlike other methods, don't bother reprocessing this data into a nicer + * format. This is because we will deprecate this route. + */ + private compressedHistogram(tag: string, run: string): + Promise> { + const url = (getRouter().pluginRunTagRoute( + 'distributions', '/distributions')(tag, run)); + let p: Promise[]>; + p = this.requestManager.request(url); + return p.then(map(detupler((x) => x))); + } + + private createImage(x: ImageMetadata): Image&Datum { + const pluginRoute = getRouter().pluginRoute('images', '/individualImage'); + + let query = x.query; + if (pluginRoute.indexOf('?') > -1) { + // The route already has GET parameters. Append our parameters to them. + query = '&' + query; + } else { + // The route lacks GET parameters. We append them. + query = '?' + query; + } + + if (getRouter().isDemoMode()) { + query = demoify(query); + } + + let individualImageUrl = pluginRoute + query; + // Include wall_time just to disambiguate the URL and force the browser + // to reload the image when the URL changes. The backend doesn't care + // about the value. + individualImageUrl += + getRouter().isDemoMode() ? '.png' : '&ts=' + x.wall_time; + + return { + width: x.width, + height: x.height, + wall_time: timeToDate(x.wall_time), + step: x.step, + url: individualImageUrl, + }; + } + + private createAudio(x: AudioMetadata): Audio&Datum { + const pluginRoute = getRouter().pluginRoute('audio', '/individualAudio'); + + let query = x.query; + if (pluginRoute.indexOf('?') > -1) { + // The route already has GET parameters. Append our parameters to them. + query = '&' + query; + } else { + // The route lacks GET parameters. We append them. + query = '?' + query; + } + + if (getRouter().isDemoMode()) { + query = demoify(query); + } + + let individualAudioUrl = pluginRoute + query; + // Include wall_time just to disambiguate the URL and force the browser + // to reload the audio when the URL changes. The backend doesn't care + // about the value. + individualAudioUrl += + getRouter().isDemoMode() ? '.wav' : '&ts=' + x.wall_time; + + return { + content_type: x.content_type, + wall_time: timeToDate(x.wall_time), + step: x.step, + url: individualAudioUrl, + }; } } + +/** Given a RunToTag, return sorted array of all runs */ +export function getRuns(r: RunToTag): string[] { + return _.keys(r).sort(compareTagNames); +} + +/** Given a RunToTag, return array of all tags (sorted + dedup'd) */ +export function getTags(r: RunToTag): string[] { + return _.union.apply(null, _.values(r)).sort(compareTagNames); +} + +/** + * Given a RunToTag and an array of runs, return every tag that appears for + * at least one run. + * Sorted, deduplicated. + */ +export function filterTags(r: RunToTag, runs: string[]): string[] { + let result = []; + runs.forEach((x) => result = result.concat(r[x])); + return _.uniq(result).sort(compareTagNames); +} + +function timeToDate(x: number): Date { + return new Date(x * 1000); +}; + +/** Just a curryable map to make things cute and tidy. 
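createImage and createAudio share one URL-assembly rule: append the metadata query with '?' or '&' depending on whether the plugin route already carries GET parameters, then add a cache-busting timestamp. A condensed sketch of that rule (paths invented):

```typescript
// Append `query` to `route`, picking the separator by whether the route
// already has GET parameters; `ts` disambiguates the URL so the browser
// re-fetches when the data changes (the backend ignores its value).
function appendQuery(route: string, query: string, ts: number): string {
  const separator = route.indexOf('?') > -1 ? '&' : '?';
  return route + separator + query + '&ts=' + ts;
}

console.log(appendQuery(
    '/data/plugin/images/individualImage', 'tag=input&run=train', 123));
// /data/plugin/images/individualImage?tag=input&run=train&ts=123
```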
*/ +function map<T, U>(f: (x: T) => U): (arr: T[]) => U[] { + return function(arr: T[]): U[] { + return arr.map(f); + }; +}; + +/** + * This is a higher order function that takes a function that transforms a + * T into a G, and returns a function that takes TupleDatas and converts + * them into the intersection of a G and a Datum. + */ +function detupler<T, G>(xform: (x: T) => G): (t: TupleData<T>) => Datum & G { + return function(x: TupleData<T>): Datum & G { + // Create a G, assert it has type <G> + let obj = xform(x[2]); + // ... patch in the properties of datum + obj.wall_time = timeToDate(x[0]); + obj.step = x[1]; + return obj; + }; +}; + +function createScalar(x: number): Scalar { + return {scalar: x}; +} + +function createHistogram(x: HistogramTuple): Histogram { + return { + min: x[0], + max: x[1], + nItems: x[2], + sum: x[3], + sumSquares: x[4], + bucketRightEdges: x[5], + bucketCounts: x[6], + }; +} + +/** + * Takes histogram data as stored by tensorboard backend and converts it to + * the standard d3 histogram data format to make it more compatible and easier + * to visualize. When visualizing histograms, having the left edge and width + * makes things quite a bit easier. The bins are also converted to have a + * uniform width, which makes the visualization easier to understand. + * + * @param histogram A histogram from tensorboard backend. + * @param min The leftmost edge. The binning will start on it. + * @param max The rightmost edge. The binning will end on it. + * @param numBins The number of bins of the converted data. The default of 30 + * is a sensible default; using more starts to produce artifacts because the event + * data is stored in buckets, and you start being able to see the aliased + * borders between each bucket. + * @return An array of histogram bins. Each bin has an x (left edge), a dx (width), + * and a y (count). + * + * If the given right edges are inclusive, then these left edges (x) are exclusive. + */ +export function convertBins( + histogram: Histogram, min: number, max: number, numBins = 30) { + if (histogram.bucketRightEdges.length !== histogram.bucketCounts.length) { + throw(new Error('Edges and counts are of different lengths.')); + } + + if (max === min) { + // Create bins even if all the data has a single value. + max = min * 1.1 + 1; + min = min / 1.1 - 1; + } + const binWidth = (max - min) / numBins; + let bucketLeft = min; // Use the min as the starting point for the bins. + let bucketPos = 0; + return d3.range(min, max, binWidth).map((binLeft) => { + const binRight = binLeft + binWidth; + + // Take the count of each existing bucket, multiply it by the proportion + // of overlap with the new bin, then sum and store as the count for the + // new bin. If no overlap, will add to zero, if 100% overlap, will include + // the full count into new bin. + let binY = 0; + while (bucketPos < histogram.bucketRightEdges.length) { + // Clip the right edge because right-most edge can be infinite-sized. + const bucketRight = Math.min(max, histogram.bucketRightEdges[bucketPos]); + + const intersect = + Math.min(bucketRight, binRight) - Math.max(bucketLeft, binLeft); + const count = (intersect / (bucketRight - bucketLeft)) * + histogram.bucketCounts[bucketPos]; + + binY += intersect > 0 ? count : 0; + + // If bucketRight is bigger than binRight, then this bin is finished and + // there is data for the next bin, so don't increment bucketPos.
+ if (bucketRight > binRight) { + break; + } + bucketLeft = Math.max(min, bucketRight); + bucketPos++; + } + + return {x: binLeft, dx: binWidth, y: binY}; + }); +} + +/** + * The following interfaces (TupleData, HistogramTuple, + * CompressedHistogramTuple, ImageMetadata, and AudioMetadata) describe how + * the data is sent over from the backend. + */ +type TupleData = [number, number, T]; // wall_time, step + +// Min, Max, nItems, Sum, Sum_Squares, right edges of buckets, nItems in +// buckets +type HistogramTuple = + [number, number, number, number, number, number[], number[]]; +type CompressedHistogramTuple = [number, number][]; // percentile, value +interface ImageMetadata { + width: number; + height: number; + wall_time: number; + step: number; + query: string; +} +interface AudioMetadata { + content_type: string; + wall_time: number; + step: number; + query: string; +} diff --git a/tensorflow/tensorboard/components/tf_backend/behavior.ts b/tensorflow/tensorboard/components/tf_backend/behavior.ts index de6590456f7..8df791eface 100644 --- a/tensorflow/tensorboard/components/tf_backend/behavior.ts +++ b/tensorflow/tensorboard/components/tf_backend/behavior.ts @@ -12,134 +12,137 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {getRuns, getTags, TYPES} from './backend'; -module TF.Backend { - export var Behavior = { - properties: { - /** *** Required properties *** */ - /** Data type. One of TF.Backend.TYPES */ - dataType: { - type: String, - observer: '_throwErrorOnUnrecognizedType', - }, - - /** TF.Backend.Backend for data loading. */ - backend: { - type: Object, - }, - - /** Should it automatically load when configured ready? Default true. */ - autoLoad: { - type: Boolean, - value: true, - }, - - /** *** Component-provided properties *** */ - /** Every tag available for data type (sorted, dedpulicated) */ - tags: { - type: Array, - readOnly: true, - notify: true, - }, - - /** Every run available for data type (sorted) */ - runs: { - type: Array, - readOnly: true, - notify: true, - }, - - /** Mapping from runs to tags for the data type */ - run2tag: { - type: Object, - readOnly: true, - notify: true, - }, - - /** Promise provider for the data. Useful for passing to subcomponents */ - dataProvider: - {type: Function, computed: '_getDataProvider(dataType, backend)'}, - - /** Has the dashboard loaded yet? */ - loadState: { - type: String, - value: 'noload', // [noload, pending, loaded, failure] - readOnly: true, - }, - - /** - * True if dashboard has loaded, and no tags were found. - * Persists through subsequent reloads (ie. still true while - * next load is pending) so warning won't flash away every reload - * when there is no data. - */ - dataNotFound: { - type: Boolean, - value: false, - readOnly: true, - } - +/** @polymerBehavior */ +export const BackendBehavior = { + properties: { + /** *** Required properties *** */ + /** Data type. One of Backend.TYPES */ + dataType: { + type: String, + observer: '_throwErrorOnUnrecognizedType', }, - observers: ['_do_autoLoad(dataType, backend, autoLoad)'], - /** - * Reloading works in two steps: - * Backend reload, which gets metadata on available runs, tags, etc from - * the backend. - * Frontend reload, which loads new data for each chart or visual display. - * Backend reload logic is provided by this behaivor. 
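convertBins distributes each backend bucket's count across the new uniform bins in proportion to overlap. A toy invocation, assuming the exported convertBins above:

```typescript
import {convertBins} from './backend';

// One backend bucket covering [0, 10) with count 10, re-binned into two
// uniform bins of width 5: each bin overlaps the bucket by 50%, so each
// receives half of the count.
const histogram = {
  min: 0,
  max: 10,
  bucketRightEdges: [10],
  bucketCounts: [10],
};
console.log(convertBins(histogram, 0, 10, 2));
// [{x: 0, dx: 5, y: 5}, {x: 5, dx: 5, y: 5}]
```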
The frontend reload - logic should be provided elsewhere, since it is component-specific. - * To keep things simple and consistent, we do the backend reload first, - * and the frontend reload afterwards. - */ - reload: function() { - return this.backendReload().then( - (x) => { return this.frontendReload(); }); + + /** Backend for data loading. */ + backend: { + type: Object, }, + + /** Should it automatically load when configured ready? Default true. */ + autoLoad: { + type: Boolean, + value: true, + }, + + /** *** Component-provided properties *** */ + /** Every tag available for data type (sorted, deduplicated) */ + tags: { + type: Array, + readOnly: true, + notify: true, + }, + + /** Every run available for data type (sorted) */ + runs: { + type: Array, + readOnly: true, + notify: true, + }, + + /** Mapping from runs to tags for the data type */ + run2tag: { + type: Object, + readOnly: true, + notify: true, + }, + + /** Promise provider for the data. Useful for passing to subcomponents */ + dataProvider: + {type: Function, computed: '_getDataProvider(dataType, backend)'}, + + /** Has the dashboard loaded yet? */ + loadState: { + type: String, + value: 'noload', // [noload, pending, loaded, failure] + readOnly: true, + }, + /** - * Load data from backend and then set run2tag, tags, runs, and loadState. - * Returns a promise that resolves/rejects when data is loaded. + * True if dashboard has loaded, and no tags were found. + * Persists through subsequent reloads (i.e., still true while + * next load is pending) so warning won't flash away every reload + * when there is no data. */ - backendReload: function() { - if (this.dataType == null) { - throw new Error('TF.Backend.Behavior: Need a dataType to reload.'); - } - if (this.backend == null) { - throw new Error('TF.Backend.Behavior: Need a backend to reload.'); - } - var runsRoute = this.backend[this.dataType + 'Runs'].bind(this.backend); - this._setLoadState('pending'); - return runsRoute().then( - (x) => { - this._setLoadState('loaded'); - if (_.isEqual(x, this.run2tag)) { - // If x and run2tag are equal, let's avoid updating everything - // since that can needlessly trigger run changes, reloads, etc - return x; - } - this._setRun2tag(x); - var tags = TF.Backend.getTags(x); - this._setDataNotFound(tags.length === 0); - this._setTags(tags); - this._setRuns(TF.Backend.getRuns(x)); + dataNotFound: { + type: Boolean, + value: false, + readOnly: true, + } + + }, + observers: ['_do_autoLoad(dataType, backend, autoLoad)'], + /** + * Reloading works in two steps: + * Backend reload, which gets metadata on available runs, tags, etc. from + * the backend. + * Frontend reload, which loads new data for each chart or visual display. + * Backend reload logic is provided by this behavior. The frontend reload + * logic should be provided elsewhere, since it is component-specific. + * To keep things simple and consistent, we do the backend reload first, + * and the frontend reload afterwards. + */ + reload() { + return this.backendReload().then((x) => { + return this.frontendReload(); + }); + }, + /** + * Load data from backend and then set run2tag, tags, runs, and loadState. + * Returns a promise that resolves/rejects when data is loaded.
+ */ + backendReload() { + if (this.dataType == null) { + throw new Error('BackendBehavior: Need a dataType to reload.'); + } + if (this.backend == null) { + throw new Error('BackendBehavior: Need a backend to reload.'); + } + const runsRoute = (this.backend[this.dataType + 'Runs'] || + this.backend[this.dataType + 'Tags']) + .bind(this.backend); + this._setLoadState('pending'); + return runsRoute().then( + (x) => { + this._setLoadState('loaded'); + if (_.isEqual(x, this.run2tag)) { + // If x and run2tag are equal, let's avoid updating everything + // since that can needlessly trigger run changes, reloads, etc return x; - }, - (fail) => { - this._setLoadState('failure'); - return fail; - }); - }, - _do_autoLoad: function(type, backend, autoLoad) { - if (autoLoad) { - this.reload(); - }; - }, - _getDataProvider: function(dataType, backend) { - return this.backend[this.dataType].bind(this.backend); - }, - _throwErrorOnUnrecognizedType: function(dataType) { - if (TF.Backend.TYPES.indexOf(dataType) === -1) { - throw new Error('TF.Backend.Behavior: Unknown dataType ' + dataType); - } - }, - }; -} + } + this._setRun2tag(x); + const tags = getTags(x); + this._setDataNotFound(tags.length === 0); + this._setTags(tags); + this._setRuns(getRuns(x)); + return x; + }, + (fail) => { + this._setLoadState('failure'); + return fail; + }); + }, + _do_autoLoad(type, backend, autoLoad) { + if (autoLoad) { + this.reload(); + } + }, + _getDataProvider(dataType, backend) { + return this.backend[this.dataType].bind(this.backend); + }, + _throwErrorOnUnrecognizedType(dataType) { + if (TYPES.indexOf(dataType) === -1) { + throw new Error('BackendBehavior: Unknown dataType ' + dataType); + } + }, +}; diff --git a/tensorflow/tensorboard/components/tf_backend/requestManager.ts b/tensorflow/tensorboard/components/tf_backend/requestManager.ts index 1dfc3348b59..0fa198416e8 100644 --- a/tensorflow/tensorboard/components/tf_backend/requestManager.ts +++ b/tensorflow/tensorboard/components/tf_backend/requestManager.ts @@ -13,166 +13,165 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF.Backend { - interface ResolveReject { resolve: Function; reject: Function; } - /** - * Manages many fetch requests. Launches up to nSimultaneousRequests - * simultaneously, and maintains a LIFO queue of requests to process when - * more urls are requested than can be handled at once. The queue can be - * cleared. - * - * When a request is made, a Promise is returned which resolves with the - * parsed JSON result from the request. - */ - export class RequestCancellationError extends Error { - public name = 'RequestCancellationError'; - } +interface ResolveReject { + resolve: Function; + reject: Function; +} +/** + * Manages many fetch requests. Launches up to nSimultaneousRequests + * simultaneously, and maintains a LIFO queue of requests to process when + * more urls are requested than can be handled at once. The queue can be + * cleared. + * + * When a request is made, a Promise is returned which resolves with the + * parsed JSON result from the request. 
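backendReload now resolves its data route dynamically from the dataType string, preferring the legacy '<type>Runs' method and falling back to the newer '<type>Tags' one. A sketch of that lookup (the explicit error is added here for clarity; the original simply fails on binding undefined):

```typescript
// Resolve 'scalar' -> backend.scalarRuns or backend.scalarTags, mirroring
// the dynamic dispatch in backendReload above.
function resolveRunsRoute(
    backend: {[name: string]: Function}, dataType: string): Function {
  const route = backend[dataType + 'Runs'] || backend[dataType + 'Tags'];
  if (!route) {
    throw new Error('No Runs/Tags route for dataType ' + dataType);
  }
  return route.bind(backend);
}
```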
+ */ +export class RequestCancellationError extends Error { + public name = 'RequestCancellationError'; +} - export class RequestNetworkError extends Error { - public name: string; - public req: XMLHttpRequest; - public url: string; +export class RequestNetworkError extends Error { + public name: string; + public req: XMLHttpRequest; + public url: string; - constructor(req: XMLHttpRequest, url) { - super(); - this.message = `RequestNetworkError: ${req.status} at ${url}`; - this.name = 'RequestNetworkError'; - this.req = req; - this.url = url; - } - } - - export class RequestManager { - private _queue: ResolveReject[]; - private _maxRetries: number; - private _nActiveRequests: number; - private _nSimultaneousRequests: number; - - constructor(nSimultaneousRequests = 10, maxRetries = 3) { - this._queue = []; - this._nActiveRequests = 0; - this._nSimultaneousRequests = nSimultaneousRequests; - this._maxRetries = maxRetries; - } - - /** - * Gives a promise that loads assets from given url (respects queuing). If - * postData is provided, this request will use POST, not GET. This is an - * object mapping POST keys to string values. - */ - public request( - url: string, postData?: {[key: string]: string}): Promise { - var promise = new Promise((resolve, reject) => { - var resolver = {resolve: resolve, reject: reject}; - this._queue.push(resolver); - this.launchRequests(); - }) - .then(() => { - return this.promiseWithRetries( - url, this._maxRetries, postData); - }) - .then( - (response) => { - // Success - Let's free space for another active - // reqest, and launch it - this._nActiveRequests--; - this.launchRequests(); - return response; - }, - (rejection) => { - if (rejection.name === 'RequestNetworkError') { - // If we failed due to network error, we should - // decrement - // _nActiveRequests because this request was - // active - this._nActiveRequests--; - this.launchRequests(); - } - return Promise.reject(rejection); - }); - return promise; - } - - public clearQueue() { - while (this._queue.length > 0) { - this._queue.pop().reject( - new RequestCancellationError('Request cancelled by clearQueue')); - } - } - - /* Return number of currently pending requests */ - public activeRequests(): number { - return this._nActiveRequests; - } - - /* Return total number of outstanding requests (includes queue) */ - public outstandingRequests(): number { - return this._nActiveRequests + this._queue.length; - } - - private launchRequests() { - while (this._nActiveRequests < this._nSimultaneousRequests && - this._queue.length > 0) { - this._nActiveRequests++; - this._queue.pop().resolve(); - } - } - - /** - * Try to request a given URL using overwritable _promiseFromUrl method. - * If the request fails for any reason, we will retry up to maxRetries - * times. In practice, this will help us paper over transient network issues - * like '502 Bad Gateway'. - * By default, Chrome displays network errors in console, so - * the user will be able to tell when the requests are failing. I think this - * is a feature, if the request failures and retries are causing any - * pain to users, they can see it and file issues. 
- */ - private promiseWithRetries( - url: string, - maxRetries: number, - postData?: {[key: string]: string}) { - var success = (x) => x; - var failure = (x) => { - if (maxRetries > 0) { - return this.promiseWithRetries(url, maxRetries - 1, postData); - } else { - return Promise.reject(x); - } - }; - return this._promiseFromUrl(url, postData).then(success, failure); - } - - /* Actually get promise from url using XMLHttpRequest */ - protected _promiseFromUrl(url:string, postData?: {[key: string]: string}) { - return new Promise((resolve, reject) => { - let req = new XMLHttpRequest(); - req.open(postData ? 'POST' : 'GET', url); - - let formData; - if (postData) { - // We are to make a POST request. - formData = new FormData(); - for (let postKey in postData) { - if (postKey) { - // The linter requires 'for in' loops to be filtered by an if - // condition. - formData.append(postKey, postData[postKey]); - } - } - } - req.onload = function() { - if (req.status === 200) { - resolve(JSON.parse(req.responseText)); - } else { - reject(new RequestNetworkError(req, url)); - } - }; - req.onerror = function() { - reject(new RequestNetworkError(req, url)); - }; - req.send(formData); - }); - } + constructor(req: XMLHttpRequest, url) { + super(); + this.message = `RequestNetworkError: ${req.status} at ${url}`; + this.name = 'RequestNetworkError'; + this.req = req; + this.url = url; + } +} + +export class RequestManager { + private _queue: ResolveReject[]; + private _maxRetries: number; + private _nActiveRequests: number; + private _nSimultaneousRequests: number; + + constructor(nSimultaneousRequests = 10, maxRetries = 3) { + this._queue = []; + this._nActiveRequests = 0; + this._nSimultaneousRequests = nSimultaneousRequests; + this._maxRetries = maxRetries; + } + + /** + * Gives a promise that loads assets from given url (respects queuing). If + * postData is provided, this request will use POST, not GET. This is an + * object mapping POST keys to string values. + */ + public request(url: string, postData?: {[key: string]: string}): + Promise { + const promise = + new Promise((resolve, reject) => { + const resolver = {resolve: resolve, reject: reject}; + this._queue.push(resolver); + this.launchRequests(); + }) + .then(() => { + return this.promiseWithRetries(url, this._maxRetries, postData); + }) + .then( + (response) => { + // Success - Let's free space for another active + // request, and launch it + this._nActiveRequests--; + this.launchRequests(); + return response; + }, + (rejection) => { + if (rejection.name === 'RequestNetworkError') { + // If we failed due to network error, we should + // decrement + // _nActiveRequests because this request was + // active + this._nActiveRequests--; + this.launchRequests(); + } + return Promise.reject(rejection); + }); + return promise; + } + + public clearQueue() { + while (this._queue.length > 0) { + this._queue.pop().reject( + new RequestCancellationError('Request cancelled by clearQueue')); + } + } + + /* Return number of currently pending requests */ + public activeRequests(): number { + return this._nActiveRequests; + } + + /* Return total number of outstanding requests (includes queue) */ + public outstandingRequests(): number { + return this._nActiveRequests + this._queue.length; + } + + private launchRequests() { + while (this._nActiveRequests < this._nSimultaneousRequests && + this._queue.length > 0) { + this._nActiveRequests++; + this._queue.pop().resolve(); + } + } + + /** + * Try to request a given URL using overwritable _promiseFromUrl method. 
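The rewritten RequestManager gates concurrency with a counter plus a LIFO queue: each request first waits for a slot, then fetches with retries. Hypothetical usage (URLs invented), assuming the class as defined above:

```typescript
import {RequestManager} from './requestManager';

// At most two in-flight requests; later calls queue up and are launched
// LIFO as slots free up.
const rm = new RequestManager(2 /* nSimultaneousRequests */);
const urls = ['data/runs', 'data/logdir', 'data/plugin/scalars/tags'];
Promise.all(urls.map((url) => rm.request(url))).then((responses) => {
  console.log(rm.outstandingRequests());  // 0: queue drained
  console.log(responses.length);          // 3
});
```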
+ * If the request fails for any reason, we will retry up to maxRetries + * times. In practice, this will help us paper over transient network issues + * like '502 Bad Gateway'. + * By default, Chrome displays network errors in console, so + * the user will be able to tell when the requests are failing. I think this + * is a feature, if the request failures and retries are causing any + * pain to users, they can see it and file issues. + */ + private promiseWithRetries( + url: string, maxRetries: number, postData?: {[key: string]: string}) { + var success = (x) => x; + var failure = (x) => { + if (maxRetries > 0) { + return this.promiseWithRetries(url, maxRetries - 1, postData); + } else { + return Promise.reject(x); + } + }; + return this._promiseFromUrl(url, postData).then(success, failure); + } + + /* Actually get promise from url using XMLHttpRequest */ + protected _promiseFromUrl(url: string, postData?: {[key: string]: string}) { + return new Promise((resolve, reject) => { + let req = new XMLHttpRequest(); + req.open(postData ? 'POST' : 'GET', url); + + let formData; + if (postData) { + // We are to make a POST request. + formData = new FormData(); + for (let postKey in postData) { + if (postKey) { + // The linter requires 'for in' loops to be filtered by an if + // condition. + formData.append(postKey, postData[postKey]); + } + } + } + req.onload = function() { + if (req.status === 200) { + resolve(JSON.parse(req.responseText)); + } else { + reject(new RequestNetworkError(req, url)); + } + }; + req.onerror = function() { + reject(new RequestNetworkError(req, url)); + }; + req.send(formData); + }); } } diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts index d2c8191cc86..598546004e1 100644 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ b/tensorflow/tensorboard/components/tf_backend/router.ts @@ -12,94 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF.Backend { - export type RunTagUrlFn = (tag: string, run: string) => string; - export interface Router { - logdir: () => string; - runs: () => string; - scalars: RunTagUrlFn; - histograms: RunTagUrlFn; - compressedHistograms: RunTagUrlFn; - images: RunTagUrlFn; - individualImage: (query: string, wallTime: number) => string; - audio: RunTagUrlFn; - individualAudio: (query: string) => string; - graph: (run: string, limit_attr_size?: number, large_attrs_key?: string) - => string; - runMetadata: RunTagUrlFn; - healthPills: () => string; - }; +import {demoify, queryEncoder} from './urlPathHelpers' - /** - * The standard router for communicating with the TensorBoard backend - * @param dataDir {string} The base prefix for finding data on server. - * @param demoMode {boolean} Whether to modify urls for filesystem demo usage. - */ - export function router(dataDir = '/data', demoMode = false): Router { - var clean = demoMode ? 
demoify : (x) => x; - if (dataDir[dataDir.length - 1] === '/') { - dataDir = dataDir.slice(0, dataDir.length - 1); - } - function standardRoute(route: string, demoExtension = '.json'): - ((tag: string, run: string) => string) { - return function(tag: string, run: string): string { - var url = - dataDir + '/' + route + clean(queryEncoder({tag: tag, run: run})); - if (demoMode) { - url += demoExtension; - } - return url; - }; - } - function individualImageUrl(query: string, wallTime: number) { - var url = dataDir + '/' + clean('individualImage?' + query); - // Include wall_time just to disambiguate the URL and force the browser - // to reload the image when the URL changes. The backend doesn't care - // about the value. - url += demoMode ? '.png' : '&ts=' + wallTime; - return url; - } - function individualAudioUrl(query: string) { - var url = dataDir + '/' + clean('individualAudio?' + query); - if (demoMode) { - url += '.wav'; - } - return url; - } - function graphUrl(run: string, limit_attr_size?: number, - large_attrs_key?: string) { - let query_params = [['run', clean(run)]]; - if (limit_attr_size != null && !demoMode) { - query_params.push(['limit_attr_size', String(limit_attr_size)]); - } - if (large_attrs_key != null && !demoMode) { - query_params.push(['large_attrs_key', large_attrs_key]); - } - let query = query_params - .map(param => { - return param[0] + '=' + encodeURIComponent(param[1]); - }) - .join('&'); - var url = dataDir + '/graph' + clean('?' + query); - if (demoMode) { - url += '.pbtxt'; - } - return url; - } - return { - logdir: () => dataDir + '/logdir', - runs: () => dataDir + '/runs' + (demoMode ? '.json' : ''), - individualImage: individualImageUrl, - individualAudio: individualAudioUrl, - graph: graphUrl, - scalars: standardRoute('scalars'), - histograms: standardRoute('histograms'), - compressedHistograms: standardRoute('compressedHistograms'), - images: standardRoute('images'), - audio: standardRoute('audio'), - runMetadata: standardRoute('run_metadata', '.pbtxt'), - healthPills: () => dataDir + '/plugin/debugger/health_pills', - }; - }; +export type RunTagUrlFn = (tag: string, run: string) => string; + +export interface Router { + logdir: () => string; + runs: () => string; + isDemoMode: () => boolean; + textRuns: () => string; + text: RunTagUrlFn; + healthPills: () => string; + pluginRoute: (pluginName: string, route: string) => string; + pluginRunTagRoute: (pluginName: string, route: string) => RunTagUrlFn; +} +; + +/** + * Create a router for communicating with the TensorBoard backend. You + * can pass this to `setRouter` to make it the global router. + * + * @param dataDir {string} The base prefix for finding data on server. + * @param demoMode {boolean} Whether to modify urls for filesystem demo usage. + */ +export function createRouter(dataDir = 'data', demoMode = false): Router { + var clean = demoMode ? 
demoify : (x) => x; + if (dataDir[dataDir.length - 1] === '/') { + dataDir = dataDir.slice(0, dataDir.length - 1); + } + function standardRoute(route: string, demoExtension = '.json'): + ((tag: string, run: string) => string) { + return function(tag: string, run: string): string { + var url = + dataDir + '/' + route + clean(queryEncoder({tag: tag, run: run})); + if (demoMode) { + url += demoExtension; + } + return url; + }; + } + function pluginRoute(pluginName: string, route: string): string { + return `${dataDir}/plugin/${pluginName}${route}`; + } + function pluginRunTagRoute(pluginName: string, route: string): + ((tag: string, run: string) => string) { + const base = pluginRoute(pluginName, route); + return (tag, run) => base + clean(queryEncoder({tag, run})); + } + return { + logdir: () => dataDir + '/logdir', + runs: () => dataDir + '/runs' + (demoMode ? '.json' : ''), + isDemoMode: () => demoMode, + healthPills: () => dataDir + '/plugin/debugger/health_pills', + textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? '.json' : ''), + text: standardRoute('plugin/text/text'), + pluginRoute, + pluginRunTagRoute, + }; +}; + +let _router: Router = createRouter(); + +/** + * @return {Router} the global router + */ +export function getRouter(): Router { + return _router; +} + +/** + * Set the global router, to be returned by future calls to `getRouter`. + * You may wish to invoke this if you are running a demo server with a + * custom path prefix, or if you have customized the TensorBoard backend + * to use a different path. + * + * @param {Router} router the new global router + */ +export function setRouter(router: Router): void { + if (router == null) { + throw new Error('Router required, but got: ' + router); + } + _router = router; } diff --git a/tensorflow/tensorboard/components/tf_backend/runsStore.ts b/tensorflow/tensorboard/components/tf_backend/runsStore.ts new file mode 100644 index 00000000000..bcaff994ce8 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_backend/runsStore.ts @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {RequestManager} from './requestManager'; +import {getRouter} from './router'; + +let runs: string[] = []; + +export type Listener = () => void; +const listeners = new Set(); + +const requestManager = new RequestManager(1 /* simultaneous request */); + +/** + * Register a listener (nullary function) to be called when new runs are + * available. + */ +export function addListener(listener: Listener): void { + listeners.add(listener); +} + +/** + * Remove a listener registered with `addListener`. + */ +export function removeListener(listener: Listener): void { + listeners.delete(listener); +} + +/** + * Asynchronously load or reload the runs data. Listeners will be + * invoked if this causes the runs data to change. 
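The new router centralizes URL construction: pluginRoute yields a plugin's base path and pluginRunTagRoute appends an encoded tag/run query. Roughly, assuming queryEncoder produces a standard '?key=value' query string:

```typescript
import {createRouter, getRouter, setRouter} from './router';

// Point the global router at a custom data prefix, then build URLs.
setRouter(createRouter('/custom-data'));
getRouter().pluginRoute('scalars', '/tags');
// -> '/custom-data/plugin/scalars/tags'
getRouter().pluginRunTagRoute('scalars', '/scalars')('loss', 'train');
// -> roughly '/custom-data/plugin/scalars/scalars?tag=loss&run=train'
```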
+ * + * @see addListener + * @return {Promise} a promise that resolves when the runs have + * loaded + */ +export function fetchRuns(): Promise { + const url = getRouter().runs(); + return requestManager.request(url).then(newRuns => { + if (!_.isEqual(runs, newRuns)) { + runs = newRuns; + listeners.forEach(listener => { + listener(); + }); + } + }); +} + +/** + * Get the current list of runs. If no data is available, this will be + * an empty array (i.e., there is no distinction between "no runs" and + * "no runs yet"). + */ +export function getRuns(): string[] { + return runs.slice(); +} diff --git a/tensorflow/tensorboard/components/tf_backend/test/BUILD b/tensorflow/tensorboard/components/tf_backend/test/BUILD new file mode 100644 index 00000000000..da70f8a9daa --- /dev/null +++ b/tensorflow/tensorboard/components/tf_backend/test/BUILD @@ -0,0 +1,32 @@ +package( + default_testonly = True, + default_visibility = ["//tensorflow/tensorboard:internal"], +) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "test", + srcs = [ + "tests.html", + "backendTests.ts", + "behaviorTests.ts", + "requestManagerTests.ts", + ] + glob(["data/**"]), + path = "/tf-backend/test", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + testonly = 0, + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts index 4b91e9f62c8..029c8359125 100644 --- a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts +++ b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts @@ -12,290 +12,283 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
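The runs store above is a small publish/subscribe cache around the `/runs` route. A hedged sketch of how a consumer might use it (the `onRunsChanged` handler is hypothetical):

```ts
import {addListener, fetchRuns, getRuns, removeListener} from './runsStore';

// Fires only when the fetched list actually differs from the cached one
// (fetchRuns compares old and new with _.isEqual before notifying).
const onRunsChanged = () => {
  console.log('runs are now:', getRuns());  // getRuns() returns a copy
};
addListener(onRunsChanged);

// Kick off (or refresh) the data; concurrent calls are serialized because
// the store's RequestManager allows one simultaneous request.
fetchRuns();

// Later, when the consumer goes away (e.g., a Polymer detach callback):
removeListener(onRunsChanged);
```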
==============================================================================*/ -var assert = chai.assert; +import {Backend, convertBins, filterTags, getRuns, getTags, RunToTag, TYPES} from '../backend'; +import {RequestManager} from '../requestManager'; +import {createRouter, setRouter} from '../router'; +import {BAD_CHARACTERS, demoify, queryEncoder} from '../urlPathHelpers'; -module TF.Backend { - describe('urlPathHelpers', function() { - let demoify = TF.Backend.demoify; - let encode = TF.Backend.queryEncoder; - it('demoify works as expected', function() { - let demoified = demoify(BAD_CHARACTERS); - let all_clean = ''; - for (let i = 0; i < BAD_CHARACTERS.length; i++) { - all_clean += '_'; - } - assert.equal(demoified, all_clean, 'cleaning the BAD_CHARACTERS works'); - assert.equal(demoify('foozod'), 'foozod', 'doesnt change safe string'); - assert.equal(demoify('foo zod (2)'), 'foo_zod__2_', 'simple case'); - }); +describe('urlPathHelpers', () => { + it('demoify works as expected', () => { + const demoified = demoify(BAD_CHARACTERS); + let allClean = ''; + for (let i = 0; i < BAD_CHARACTERS.length; i++) { + allClean += '_'; + } + chai.assert.equal(demoified, allClean, 'cleaning the BAD_CHARACTERS works'); + chai.assert.equal(demoify('foozod'), 'foozod', 'doesnt change safe string'); + chai.assert.equal(demoify('foo zod (2)'), 'foo_zod__2_', 'simple case'); + }); - it('queryEncoder works with demoify on spaces and parens', function() { - let params = {foo: 'something with spaces and (parens)'}; - let actual = demoify(encode(params)); - let expected = '_foo_something_with_spaces_and__28parens_29'; - assert.equal(actual, expected); + it('queryEncoder works with demoify on spaces and parens', () => { + const params = {foo: 'something with spaces and (parens)'}; + const actual = demoify(queryEncoder(params)); + const expected = '_foo_something_with_spaces_and__28parens_29'; + chai.assert.equal(actual, expected); + }); +}); + +function assertIsDatum(x) { + chai.assert.isNumber(x.step); + chai.assert.instanceOf(x.wall_time, Date); +} + +describe('backend tests', () => { + let backend: Backend; + let rm: RequestManager; + const base = 'data'; + const demoRouter = createRouter(base, /*demoMode=*/true); + beforeEach(() => { + // Construct a demo Backend (third param is true) + setRouter(demoRouter); + backend = new Backend(); + rm = new RequestManager(); + }); + + it('runs are loaded properly', (done) => { + const runsResponse = backend.runs(); + const actualRuns = rm.request(demoRouter.runs()); + Promise.all([runsResponse, actualRuns]).then((values) => { + chai.assert.deepEqual(values[0], values[1]); + done(); }); }); - function assertIsDatum(x) { - assert.isNumber(x.step); - assert.instanceOf(x.wall_time, Date); + it('scalars are loaded properly', (done) => { + backend.scalar('cross_entropy (1)', 'run1').then((s) => { + // just check the data got reformatted properly + const aScalar = s[s.length - 1]; + assertIsDatum(aScalar); + chai.assert.isNumber(aScalar.scalar); + // verify date conversion works + chai.assert.equal(aScalar.wall_time.valueOf(), 40000); + done(); + }); + }); + + it('histograms are loaded properly', (done) => { + backend.histogram('histo1', 'run1').then((histos) => { + const histo = histos[0]; + assertIsDatum(histo); + chai.assert.instanceOf(histo.bins, Array); + done(); + }); + }); + + it('all registered types have handlers', () => { + TYPES.forEach((t: string) => { + chai.assert.isDefined(backend[t], t); + chai.assert.isDefined(backend[t + 'Runs'], t + 'Runs'); + }); + 
}); + + it('images are loaded properly', (done) => { + backend.image('im1', 'run1').then((images) => { + const image = images[0]; + assertIsDatum(image); + chai.assert.isNumber(image.width); + chai.assert.isNumber(image.height); + done(); + }); + }); + + it('audio is loaded properly', (done) => { + backend.audio('audio1', 'run1').then((audioClips) => { + const audio = audioClips[0]; + assertIsDatum(audio); + chai.assert.equal(audio.content_type, 'audio/wav'); + done(); + }); + }); + + it('trailing slash removed from base route', () => { + const r = createRouter('foo/'); + chai.assert.equal(r.runs(), 'foo/runs'); + }); + + it('run helper methods work', (done) => { + const scalar = {run1: ['cross_entropy (1)'], fake_run_no_data: ['scalar2']}; + const image = {run1: ['im1'], fake_run_no_data: ['im1', 'im2']}; + const audio = {run1: ['audio1'], fake_run_no_data: ['audio1', 'audio2']}; + const runMetadata = {run1: ['step99'], fake_run_no_data: ['step99']}; + const graph = ['fake_run_no_data']; + let count = 0; + function next() { + count++; + if (count === 4) { + done(); + } + } + backend.scalarTags().then((x) => { + chai.assert.deepEqual(x, scalar); + next(); + }); + backend.imageTags().then((x) => { + chai.assert.deepEqual(x, image); + next(); + }); + backend.audioTags().then((x) => { + chai.assert.deepEqual(x, audio); + next(); + }); + backend.runMetadataTags().then((x) => { + chai.assert.deepEqual(x, runMetadata); + next(); + }); + backend.graphRuns().then((x) => { + chai.assert.deepEqual(x, graph); + next(); + }); + }); + + it('runToTag helpers work', () => { + const r2t: RunToTag = { + run1: ['foo', 'bar', 'zod'], + run2: ['zod', 'zoink'], + a: ['foo', 'zod'] + }; + const empty1: RunToTag = {}; + const empty2: RunToTag = {run1: [], run2: []}; + chai.assert.deepEqual(getRuns(r2t), ['a', 'run1', 'run2']); + chai.assert.deepEqual(getTags(r2t), ['bar', 'foo', 'zod', 'zoink']); + chai.assert.deepEqual(filterTags(r2t, ['run1', 'run2']), getTags(r2t)); + chai.assert.deepEqual(filterTags(r2t, ['run1']), ['bar', 'foo', 'zod']); + chai.assert.deepEqual( + filterTags(r2t, ['run2', 'a']), ['foo', 'zod', 'zoink']); + + chai.assert.deepEqual(getRuns(empty1), []); + chai.assert.deepEqual(getTags(empty1), []); + + chai.assert.deepEqual(getRuns(empty2), ['run1', 'run2']); + chai.assert.deepEqual(getTags(empty2), []); + }); +}); + +describe('Verify that the histogram format conversion works.', () => { + + function assertHistogramEquality(h1, h2) { + h1.forEach((b1, i) => { + const b2 = h2[i]; + chai.assert.closeTo(b1.x, b2.x, 1e-10); + chai.assert.closeTo(b1.dx, b2.dx, 1e-10); + chai.assert.closeTo(b1.y, b2.y, 1e-10); + }); } - describe('backend tests', function() { - let backend: Backend; - let rm: RequestManager; - let base = 'data'; - let demoRouter = TF.Backend.router(base, true); - beforeEach(function() { - // Construct a demo Backend (third param is true) - backend = new Backend(demoRouter); - rm = new RequestManager(); - }); - - it('runs are loaded properly', function(done) { - let runsResponse = backend.runs(); - let actualRuns = rm.request(demoRouter.runs()); - Promise.all([runsResponse, actualRuns]).then((values) => { - assert.deepEqual(values[0], values[1]); - done(); - }); - }); - - it('scalars are loaded properly', function(done) { - backend.scalar('cross_entropy (1)', 'run1').then((s) => { - // just check the data got reformatted properly - let aScalar = s[s.length - 1]; - assertIsDatum(aScalar); - assert.isNumber(aScalar.scalar); - // verify date conversion works - 
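The `runToTag` assertions above pin down the semantics of `getRuns`, `getTags`, and `filterTags`. Restated as a minimal sketch (the `train`/`eval` data is hypothetical; the expected values in the comments are inferred from those assertions):

```ts
import {filterTags, getRuns, getTags, RunToTag} from '../backend';

const r2t: RunToTag = {
  train: ['accuracy', 'loss'],
  eval: ['loss'],
};

getRuns(r2t);               // => ['eval', 'train']     (run names, sorted)
getTags(r2t);               // => ['accuracy', 'loss']  (union of tags, sorted)
filterTags(r2t, ['eval']);  // => ['loss']              (tags in the listed runs only)
```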
assert.equal(aScalar.wall_time.valueOf(), 40000); - done(); - }); - }); - - it('histograms are loaded properly', function(done) { - backend.histogram('histo1', 'run1').then((histos) => { - let histo = histos[0]; - assertIsDatum(histo); - assert.instanceOf(histo.bins, Array); - done(); - }); - }); - - it('all registered types have handlers', function() { - TYPES.forEach((t: string) => { - assert.isDefined(backend[t], t); - assert.isDefined(backend[t + 'Runs'], t + 'Runs'); - }); - }); - - it('images are loaded properly', function(done) { - backend.image('im1', 'run1').then((images) => { - let image = images[0]; - assertIsDatum(image); - assert.isNumber(image.width); - assert.isNumber(image.height); - let nonDemoQuery = 'index=0&tag=im1&run=run1'; - let expectedUrl = demoRouter.individualImage(nonDemoQuery, 10.0); - assert.equal(image.url, expectedUrl); - done(); - }); - }); - - it('audio is loaded properly', function(done) { - backend.audio('audio1', 'run1').then((audio_clips) => { - let audio = audio_clips[0]; - assertIsDatum(audio); - assert.equal(audio.content_type, 'audio/wav'); - let nonDemoQuery = 'index=0&tag=audio1&run=run1'; - let expectedUrl = demoRouter.individualAudio(nonDemoQuery); - assert.equal(audio.url, expectedUrl); - done(); - }); - }); - - it('trailing slash removed from base route', function() { - let r = TF.Backend.router('foo/'); - assert.equal(r.runs(), 'foo/runs'); - }); - - it('run helper methods work', function(done) { - let scalar = {run1: ['cross_entropy (1)'], fake_run_no_data: ['scalar2']}; - let image = {run1: ['im1'], fake_run_no_data: ['im1', 'im2']}; - let audio = {run1: ['audio1'], fake_run_no_data: ['audio1', 'audio2']}; - let runMetadata = {run1: ['step99'], fake_run_no_data: ['step99']}; - let graph = ['fake_run_no_data']; - let count = 0; - function next() { - count++; - if (count === 4) { - done(); - } - } - backend.scalarRuns().then((x) => { - assert.deepEqual(x, scalar); - next(); - }); - backend.imageRuns().then((x) => { - assert.deepEqual(x, image); - next(); - }); - backend.audioRuns().then((x) => { - assert.deepEqual(x, audio); - next(); - }); - backend.runMetadataRuns().then((x) => { - assert.deepEqual(x, runMetadata); - next(); - }); - backend.graphRuns().then((x) => { - assert.deepEqual(x, graph); - next(); - }); - }); - - it('runToTag helpers work', function() { - let r2t: RunToTag = { - run1: ['foo', 'bar', 'zod'], - run2: ['zod', 'zoink'], - a: ['foo', 'zod'] - }; - let empty1: RunToTag = {}; - let empty2: RunToTag = {run1: [], run2: []}; - assert.deepEqual(getRuns(r2t), ['a', 'run1', 'run2']); - assert.deepEqual(getTags(r2t), ['bar', 'foo', 'zod', 'zoink']); - assert.deepEqual(filterTags(r2t, ['run1', 'run2']), getTags(r2t)); - assert.deepEqual(filterTags(r2t, ['run1']), ['bar', 'foo', 'zod']); - assert.deepEqual(filterTags(r2t, ['run2', 'a']), ['foo', 'zod', 'zoink']); - - assert.deepEqual(getRuns(empty1), []); - assert.deepEqual(getTags(empty1), []); - - assert.deepEqual(getRuns(empty2), ['run1', 'run2']); - assert.deepEqual(getTags(empty2), []); - }); + it('Throws and error if the inputs are of different lengths', () => { + chai.assert.throws(() => { + convertBins( + {bucketRightEdges: [0], bucketCounts: [1, 2], min: 1, max: 2}, 1, 2, + 2); + }, 'Edges and counts are of different lengths.'); }); - describe('Verify that the histogram format conversion works.', function() { - - function assertHistogramEquality(h1, h2) { - h1.forEach(function(b1, i) { - let b2 = h2[i]; - assert.closeTo(b1.x, b2.x, 1e-10); - assert.closeTo(b1.dx, b2.dx, 
1e-10); - assert.closeTo(b1.y, b2.y, 1e-10); - }); - } - - it('Throws and error if the inputs are of different lengths', function() { - assert.throws(function() { + it('Handles data with no bins', () => { + chai.assert.deepEqual( convertBins( - {bucketRightEdges: [0], bucketCounts: [1, 2], min: 1, max: 2}, 1, 2, - 2); - }, 'Edges and counts are of different lengths.'); - }); - - it('Handles data with no bins', function() { - assert.deepEqual( - convertBins( - {bucketRightEdges: [], bucketCounts: [], min: 0, max: 0}, 0, 0, - 0), - []); - }); - - it('Handles data with one bin', function() { - let counts = [1]; - let rightEdges = [1.21e-12]; - let histogram = [{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 1}]; - let newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: 1.1e-12, - max: 1.21e-12 - }, - 1.1e-12, 1.21e-12, 1); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles data with two bins.', function() { - let counts = [1, 2]; - let rightEdges = [1.1e-12, 1.21e-12]; - let histogram = [ - {x: 1.0e-12, dx: 1.05e-13, y: 1.09090909090909}, - {x: 1.105e-12, dx: 1.05e-13, y: 1.9090909090909} - ]; - let newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: 1.0e-12, - max: 1.21e-12 - }, - 1.0e-12, 1.21e-12, 2); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles a domain that crosses zero, but doesn\'t include zero as ' + - 'an edge.', - function() { - let counts = [1, 2]; - let rightEdges = [-1.0e-12, 1.0e-12]; - let histogram = [ - {x: -1.1e-12, dx: 1.05e-12, y: 1.95}, - {x: -0.5e-13, dx: 1.05e-12, y: 1.05} - ]; - let newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: -1.1e-12, - max: 1.0e-12 - }, - -1.1e-12, 1.0e-12, 2); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles a histogram of all zeros', function() { - let h = { - min: 0, - max: 0, - nItems: 51200, - sum: 0, - sumSquares: 0, - bucketRightEdges: [0, 1e-12, 1.7976931348623157e+308], - bucketCounts: [0, 51200, 0], - wall_time: '2017-01-25T02:30:11.257Z', - step: 0 - }; - let newHistogram = convertBins(h, 0, 0, 5); - let expectedHistogram = [ - {x: -1, dx: 0.4, y: 0}, {x: -0.6, dx: 0.4, y: 0}, - {x: -0.2, dx: 0.4, y: 51200}, {x: 0.2, dx: 0.4, y: 0}, - {x: 0.6, dx: 0.4, y: 0} - ]; - assertHistogramEquality(newHistogram, expectedHistogram); - }); - - it('Handles a right-most right edge that extends to very large number.', - function() { - let counts = [1, 2, 3]; - let rightEdges = [0, 1.0e-12, 1.0e14]; - let histogram = [ - {x: -1.0e-12, dx: 0.7e-12, y: 0.7}, - {x: -0.3e-12, dx: 0.7e-12, y: 1.1}, - {x: 0.4e-12, dx: 0.7e-12, y: 4.2} - ]; - let newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: -1.0e-12, - max: 1.1e-12 - }, - -1.0e-12, 1.1e-12, 3); - assertHistogramEquality(newHistogram, histogram); - }); + {bucketRightEdges: [], bucketCounts: [], min: 0, max: 0}, 0, 0, 0), + []); }); -} + + it('Handles data with one bin', () => { + const counts = [1]; + const rightEdges = [1.21e-12]; + const histogram = [{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 1}]; + const newHistogram = convertBins( + { + bucketRightEdges: rightEdges, + bucketCounts: counts, + min: 1.1e-12, + max: 1.21e-12 + }, + 1.1e-12, 1.21e-12, 1); + assertHistogramEquality(newHistogram, histogram); + }); + + it('Handles data with two bins.', () => { + const counts = [1, 2]; + const rightEdges = [1.1e-12, 1.21e-12]; + const histogram = [ + {x: 1.0e-12, dx: 
1.05e-13, y: 1.09090909090909}, + {x: 1.105e-12, dx: 1.05e-13, y: 1.9090909090909} + ]; + const newHistogram = convertBins( + { + bucketRightEdges: rightEdges, + bucketCounts: counts, + min: 1.0e-12, + max: 1.21e-12 + }, + 1.0e-12, 1.21e-12, 2); + assertHistogramEquality(newHistogram, histogram); + }); + + it('Handles a domain that crosses zero, but doesn\'t include zero as ' + + 'an edge.', + () => { + const counts = [1, 2]; + const rightEdges = [-1.0e-12, 1.0e-12]; + const histogram = [ + {x: -1.1e-12, dx: 1.05e-12, y: 1.95}, + {x: -0.5e-13, dx: 1.05e-12, y: 1.05} + ]; + const newHistogram = convertBins( + { + bucketRightEdges: rightEdges, + bucketCounts: counts, + min: -1.1e-12, + max: 1.0e-12 + }, + -1.1e-12, 1.0e-12, 2); + assertHistogramEquality(newHistogram, histogram); + }); + + it('Handles a histogram of all zeros', () => { + const h = { + min: 0, + max: 0, + nItems: 51200, + sum: 0, + sumSquares: 0, + bucketRightEdges: [0, 1e-12, 1.7976931348623157e+308], + bucketCounts: [0, 51200, 0], + wall_time: '2017-01-25T02:30:11.257Z', + step: 0 + }; + const newHistogram = convertBins(h, 0, 0, 5); + const expectedHistogram = [ + {x: -1, dx: 0.4, y: 0}, {x: -0.6, dx: 0.4, y: 0}, + {x: -0.2, dx: 0.4, y: 51200}, {x: 0.2, dx: 0.4, y: 0}, + {x: 0.6, dx: 0.4, y: 0} + ]; + assertHistogramEquality(newHistogram, expectedHistogram); + }); + + it('Handles a right-most right edge that extends to very large number.', + () => { + const counts = [1, 2, 3]; + const rightEdges = [0, 1.0e-12, 1.0e14]; + const histogram = [ + {x: -1.0e-12, dx: 0.7e-12, y: 0.7}, {x: -0.3e-12, dx: 0.7e-12, y: 1.1}, + {x: 0.4e-12, dx: 0.7e-12, y: 4.2} + ]; + const newHistogram = convertBins( + { + bucketRightEdges: rightEdges, + bucketCounts: counts, + min: -1.0e-12, + max: 1.1e-12 + }, + -1.0e-12, 1.1e-12, 3); + assertHistogramEquality(newHistogram, histogram); + }); +}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts b/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts index 42b6fad7fe8..6bf328140e2 100644 --- a/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts +++ b/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts @@ -12,147 +12,154 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
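The "two bins" expectation above is worth spelling out, since it shows what `convertBins` does: it resamples the variable-width TensorFlow buckets onto uniform bins spanning `[min, max]`, and the expected values are consistent with each bucket's count being split across new bins in proportion to overlap. A sketch using the same numbers as the test:

```ts
import {convertBins} from '../backend';

const bins = convertBins(
    {
      bucketRightEdges: [1.1e-12, 1.21e-12],  // two buckets of unequal width
      bucketCounts: [1, 2],
      min: 1.0e-12,
      max: 1.21e-12,
    },
    1.0e-12,   // min of the output domain
    1.21e-12,  // max of the output domain
    2);        // number of uniform output bins
// bins ~= [{x: 1.0e-12,   dx: 1.05e-13, y: 1.0909...},
//          {x: 1.105e-12, dx: 1.05e-13, y: 1.9090...}]
// The total count (3) is preserved across the resampling.
```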
==============================================================================*/ -var assert = chai.assert; + +import {Backend, getRuns, getTags, RunToTag} from '../backend' +import {BackendBehavior} from '../behavior' + declare function fixture(id: string): void; - module TF.Backend { - window.addEventListener('WebComponentsReady', function() { - Polymer({ - is: 'test-element', - behaviors: [TF.Backend.Behavior], - frontendReload: function() { - // no-op - }, - }); +window.addEventListener('WebComponentsReady', function() { + Polymer({ + is: 'test-element', + behaviors: [BackendBehavior], + frontendReload: function() { + // no-op + }, + }); +}); + +describe('data-behavior', function() { + let testElement; + let resolve; + let reject; + const fakeBackend = { + scalarTags() { + return new Promise((_resolve, _reject) => { + resolve = (x) => _resolve(x); + reject = (x) => _reject(x); }); + }, + scalar(x) { + return this; + }, + }; + beforeEach(function() { + testElement = fixture('testElementFixture'); + testElement.autoLoad = false; + testElement.backend = fakeBackend; + testElement.dataType = 'scalar'; + }); - describe('data-behavior', function() { - var testElement; - var resolve; - var reject; - var fakeBackend = { - scalarRuns: function() { - return new Promise(function(_resolve, _reject) { - resolve = (x) => _resolve(x); - reject = (x) => _reject(x); - }); - }, - scalar: function(x) { return this; }, - }; - beforeEach(function() { - testElement = fixture('testElementFixture'); - testElement.autoLoad = false; - testElement.backend = fakeBackend; - testElement.dataType = 'scalar'; + it('load states work as expected', function(done) { + chai.assert.equal(testElement.loadState, 'noload'); + var reloaded = testElement.reload(); + chai.assert.equal(testElement.loadState, 'pending'); + resolve(); + reloaded + .then(function() { + chai.assert.equal(testElement.loadState, 'loaded'); + var reloaded2 = testElement.reload(); + chai.assert.equal(testElement.loadState, 'pending'); + reject(); + return reloaded2; + }) + .then(function() { + chai.assert.equal(testElement.loadState, 'failure'); + done(); }); + }); - it('load states work as expected', function(done) { - assert.equal(testElement.loadState, 'noload'); - var reloaded = testElement.reload(); - assert.equal(testElement.loadState, 'pending'); - resolve(); - reloaded - .then(function() { - assert.equal(testElement.loadState, 'loaded'); - var reloaded2 = testElement.reload(); - assert.equal(testElement.loadState, 'pending'); - reject(); - return reloaded2; - }) - .then(function() { - assert.equal(testElement.loadState, 'failure'); - done(); - }); - }); + it('data provider set appropriately', function() { + chai.assert.deepEqual(testElement.dataProvider(), testElement.backend); + }); - it('data provider set appropriately', function() { - assert.deepEqual(testElement.dataProvider(), testElement.backend); - }); + it('loads data as expected', function(done) { + var r2t: RunToTag = { + run1: ['foo', 'bar', 'zod'], + run2: ['zoink', 'zow'], + run3: ['.'], + }; + var tags = getTags(r2t); + var runs = getRuns(r2t); + testElement.backend = fakeBackend; + testElement.dataType = 'scalar'; + testElement.reload().then(function(x) { + chai.assert.deepEqual(testElement.run2tag, r2t); + chai.assert.deepEqual(testElement.runs, runs); + chai.assert.deepEqual(testElement.tags, tags); + done(); + }); + resolve(r2t); + }); - it('loads data as expected', function(done) { - var r2t: RunToTag = { - run1: ['foo', 'bar', 'zod'], - run2: ['zoink', 'zow'], - run3: ['.'], - 
}; - var tags = TF.Backend.getTags(r2t); - var runs = TF.Backend.getRuns(r2t); - testElement.backend = fakeBackend; - testElement.dataType = 'scalar'; - testElement.reload().then(function(x) { - assert.deepEqual(testElement.run2tag, r2t); - assert.deepEqual(testElement.runs, runs); - assert.deepEqual(testElement.tags, tags); - done(); - }); - resolve(r2t); - }); + it('errors thrown on bad data types', function() { + testElement.backend = undefined; + chai.assert.throws(function() { + testElement.dataType = 'foo'; + }); + testElement.dataType = 'scalar'; + testElement.dataType = 'graph'; + testElement.dataType = 'histogram'; + }); - it('errors thrown on bad data types', function() { - testElement.backend = undefined; - assert.throws(function() { testElement.dataType = 'foo'; }); - testElement.dataType = 'scalar'; - testElement.dataType = 'graph'; - testElement.dataType = 'histogram'; - }); - - it('dataNotFound flag works', function(done) { - assert.isFalse(testElement.dataNotFound, 'initially false'); - var next = testElement.reload(); - assert.isFalse(testElement.dataNotFound, 'still false while pending'); - resolve({foo: [], bar: []}); - next.then(() => { - assert.isTrue(testElement.dataNotFound, 'true on empty data'); - var last = testElement.reload(); - assert.isTrue(testElement.dataNotFound, 'still true while pending'); - resolve({foo: ['bar'], bar: ['zod']}); - last.then(() => { - assert.isFalse( - testElement.dataNotFound, 'false now that we have data'); - done(); - }); - }); - }); - - it('reloads as soon as setup, if autoReload is true', function(done) { - var r2t = {foo: [], bar: []}; - var fakeBackend = { - scalarRuns: () => Promise.resolve(r2t), - scalar: () => null, - }; - testElement = fixture('testElementFixture'); - testElement.dataType = 'scalar'; - testElement.backend = fakeBackend; - setTimeout(() => { - assert.equal(testElement.run2tag, r2t); - done(); - }); - }); - - it('doesn\'t mutate props if backend returns same data', function( - done) { - var r2t_1 = {foo: ['1', '2'], bar: ['3', '4']}; - var r2t_2 = {foo: ['1', '2'], bar: ['3', '4']}; - var fakeBackend = { - scalarRuns: () => Promise.resolve(r2t_1), - scalar: () => null, - }; - testElement.backend = fakeBackend; - testElement.reload().then(() => { - fakeBackend.scalarRuns = () => Promise.resolve(r2t_2); - var tags = testElement.tags; - testElement.reload().then(() => { - // shallow equality ensures it wasn't recomputed - assert.equal(tags, testElement.tags, 'tags was not recomputed'); - done(); - }); - }); - - it('reload calls frontendReload', function(done) { - testElement.frontendReload = function() { done(); }; - testElement.reload(); - }); - - }); + it('dataNotFound flag works', function(done) { + chai.assert.isFalse(testElement.dataNotFound, 'initially false'); + var next = testElement.reload(); + chai.assert.isFalse(testElement.dataNotFound, 'still false while pending'); + resolve({foo: [], bar: []}); + next.then(() => { + chai.assert.isTrue(testElement.dataNotFound, 'true on empty data'); + var last = testElement.reload(); + chai.assert.isTrue(testElement.dataNotFound, 'still true while pending'); + resolve({foo: ['bar'], bar: ['zod']}); + last.then(() => { + chai.assert.isFalse( + testElement.dataNotFound, 'false now that we have data'); + done(); }); - } + }); + }); + + it('reloads as soon as setup, if autoReload is true', function(done) { + var r2t = {foo: [], bar: []}; + var fakeBackend = { + scalarTags: () => Promise.resolve(r2t), + scalar: () => null, + }; + testElement = fixture('testElementFixture'); 
+ testElement.dataType = 'scalar'; + testElement.backend = fakeBackend; + setTimeout(() => { + chai.assert.equal(testElement.run2tag, r2t); + done(); + }); + }); + + it('doesn\'t mutate props if backend returns same data', function(done) { + var r2t_1 = {foo: ['1', '2'], bar: ['3', '4']}; + var r2t_2 = {foo: ['1', '2'], bar: ['3', '4']}; + var fakeBackend = { + scalarTags: () => Promise.resolve(r2t_1), + scalar: () => null, + }; + testElement.backend = fakeBackend; + testElement.reload().then(() => { + fakeBackend.scalarTags = () => Promise.resolve(r2t_2); + var tags = testElement.tags; + testElement.reload().then(() => { + // shallow equality ensures it wasn't recomputed + chai.assert.equal(tags, testElement.tags, 'tags was not recomputed'); + done(); + }); + }); + }); + + // TODO(dandelion): Fix this test. + it('reload calls frontendReload', function(done) { + testElement.frontendReload = function() { + done(); + }; + testElement.reload(); + }); + +}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/index.html b/tensorflow/tensorboard/components/tf_backend/test/index.html deleted file mode 100644 index 7f51861d25a..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/index.html +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_backend/test/requestManagerTest.ts b/tensorflow/tensorboard/components/tf_backend/test/requestManagerTest.ts deleted file mode 100644 index b93e1569a45..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/requestManagerTest.ts +++ /dev/null @@ -1,287 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -var assert = chai.assert; - -module TF.Backend { - interface MockRequest { - resolve: Function; - reject: Function; - id: number; - url: string; - } - - class MockedRequestManager extends TF.Backend.RequestManager { - private resolvers: Function[]; - private rejectors: Function[]; - public requestsDispatched: number; - - constructor(maxRequests = 10, maxRetries = 3) { - super(maxRequests, maxRetries); - this.resolvers = []; - this.rejectors = []; - this.requestsDispatched = 0; - } - - protected _promiseFromUrl(url) { - return new Promise((resolve, reject) => { - var mockJSON = { - ok: true, - json: function() { return url; }, - url: url, - status: 200, - }; - var mockFailedRequest: any = { - ok: false, - url: url, - status: 502, - }; - var mockFailure = new RequestNetworkError(mockFailedRequest, url); - this.resolvers.push(function() { resolve(mockJSON); }); - this.rejectors.push(function() { reject(mockFailure); }); - this.requestsDispatched++; - }); - } - - public resolveFakeRequest() { - this.resolvers.pop()(); - } - - public rejectFakeRequest() { - this.rejectors.pop()(); - } - - public dispatchAndResolve() { - // Wait for at least one request to be dispatched, then resolve it. 
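Taken together, the behavior tests above imply a small state machine for elements using `BackendBehavior`. A hedged summary in code form; the `element` and `backend` objects are hypothetical, while the state and property names are taken directly from the assertions:

```ts
// Hypothetical element and data provider, as in the fixture above.
declare const element: any;  // a Polymer element using BackendBehavior
declare const backend: any;  // must expose `${dataType}Tags()` and `${dataType}()`

// States asserted in the tests: 'noload' before any reload(), 'pending'
// while a reload() is in flight, then 'loaded' on success or 'failure'
// on rejection.
element.autoLoad = false;     // otherwise it reloads as soon as it is set up
element.backend = backend;
element.dataType = 'scalar';  // unregistered type names throw

element.reload().then(() => {
  if (element.dataNotFound) {
    // every run mapped to an empty tag list
  } else {
    // element.run2tag, element.runs, and element.tags are populated
  }
});
```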
- this.waitForDispatch(1).then(() => this.resolveFakeRequest()); - } - - public waitForDispatch(num) { - return waitForCondition(() => {return this.requestsDispatched >= num; }); - } - } - - /* Create a promise that returns when *check* returns true. */ - // May cause a test timeout if check never becomes true. - function waitForCondition(check: () => boolean): Promise { - return new Promise((resolve, reject) => { - var go = function() { - if (check()) { - resolve(); - } - setTimeout(go, 2); - }; - go(); - }); - } - - describe('backend', () => { - describe('request manager', () => { - it('request loads JSON properly', (done) => { - var rm = new TF.Backend.RequestManager(); - var promise = rm.request('data/example.json'); - promise.then( - (response) => { - assert.deepEqual(response, {foo: 3, bar: 'zoidberg'}); - done(); - }, - (reject) => { throw new Error(reject); }); - }); - - it('rejects on bad url', (done) => { - var rm = new TF.Backend.RequestManager(5, 0); - var bad_url = '_bad_url_which_doesnt_exist.json'; - var promise = rm.request(bad_url); - promise.then( - (success) => { - done(new Error('the promise should have rejected')); - }, - (reject: TF.Backend.RequestNetworkError) => { - assert.instanceOf(reject, TF.Backend.RequestNetworkError); - assert.include(reject.message, '404'); - assert.include(reject.message, bad_url); - assert.equal(reject.req.status, 404); - done(); - }); - }); - - it('can retry if requests fail', (done) => { - var rm = new MockedRequestManager(3, 5); - var r = rm.request('foo'); - rm.waitForDispatch(1).then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(2); - }).then(() => rm.resolveFakeRequest()); - r.then((success) => done()); - }); - - it('retries at most maxRetries times', (done) => { - var MAX_RETRIES = 2; - var rm = new MockedRequestManager(3, MAX_RETRIES); - var r = rm.request('foo'); - rm.waitForDispatch(1).then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(2); - }).then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(3); - }).then(() => { - rm.rejectFakeRequest(); - }); - - r.then( - (success) => done(new Error('The reqest should have failed')), - (failure) => done()); - }); - - it('requestManager only sends maxRequests requests at a time', (done) => { - var rm = new MockedRequestManager(3); - var requestsConcluded = 0; - var r0 = rm.request('1'); - var r1 = rm.request('2'); - var r2 = rm.request('3'); - var r3 = rm.request('4'); - assert.equal(rm.activeRequests(), 3, 'three requests are active'); - assert.equal(rm.outstandingRequests(), 4, 'four requests are pending'); - rm.waitForDispatch(3) - .then(() => { - assert.equal( - rm.activeRequests(), 3, - 'three requests are still active (1)'); - assert.equal( - rm.requestsDispatched, 3, 'three requests were dispatched'); - rm.resolveFakeRequest(); - return rm.waitForDispatch(4); - }) - .then(() => { - assert.equal( - rm.activeRequests(), 3, - 'three requests are still active (2)'); - assert.equal( - rm.requestsDispatched, 4, 'four requests were dispatched'); - assert.equal( - rm.outstandingRequests(), 3, 'three requests are pending'); - rm.resolveFakeRequest(); - rm.resolveFakeRequest(); - rm.resolveFakeRequest(); - return r3; - }) - .then(() => { - assert.equal(rm.activeRequests(), 0, 'all requests finished'); - assert.equal(rm.outstandingRequests(), 0, 'no requests pending'); - done(); - }); - }); - - it('queue continues after failures', (done) => { - var rm = new MockedRequestManager(1, 0); - var r0 = rm.request('1'); - var r1 = rm.request('2'); - 
rm.waitForDispatch(1).then(() => { - rm.rejectFakeRequest(); - }); - - r0.then( - (success) => done(new Error('r0 should have failed')), - (failure) => 'unused_argument') - .then(() => rm.resolveFakeRequest()); - - // When the first request rejects, it should decrement nActiveRequests - // and then launch remaining requests in queue (i.e. this one) - r1.then((success) => done(), - (failure) => done(new Error(failure))); - }); - - it('queue is LIFO', (done) => { - /* This test is a bit tricky. - * We want to verify that the RequestManager queue has LIFO semantics. - * So we construct three requests off the bat: A, B, C. - * So LIFO semantics ensure these will resolve in order A, C, B. - * (Because the A request launches immediately when we create it, it's - * not in queue) - * Then after resolving A, C moves out of queue, and we create X. - * So expected final order is A, C, X, B. - * We verify this with an external var that counts how many requests were - * resolved. - */ - var rm = new MockedRequestManager(1); - var nResolved = 0; - function assertResolutionOrder(expectedSpotInSequence) { - return function() { - nResolved++; - assert.equal(expectedSpotInSequence, nResolved); - }; - } - - function launchThirdRequest() { - rm.request('started late but goes third') - .then(assertResolutionOrder(3)) - .then(() => rm.dispatchAndResolve()); - } - - rm.request('first') - .then(assertResolutionOrder( - 1)) // Assert that this one resolved first - .then(launchThirdRequest) - .then(() => rm.dispatchAndResolve()); // then trigger the next one - - rm.request('this one goes fourth') // created second, will go last - .then(assertResolutionOrder( - 4)) // assert it was the fourth to get resolved - .then(done); // finish the test - - rm.request('second') - .then(assertResolutionOrder(2)) - .then(() => rm.dispatchAndResolve()); - - rm.dispatchAndResolve(); - }); - - it('requestManager can clear queue', (done) => { - var rm = new MockedRequestManager(1); - var requestsResolved = 0; - var requestsRejected = 0; - var success = () => requestsResolved++; - var failure = (err) => { - assert.equal(err.name, 'RequestCancellationError'); - requestsRejected++; - }; - var finishTheTest = () => { - assert.equal(rm.activeRequests(), 0, 'no requests still active'); - assert.equal( - rm.requestsDispatched, 1, 'only one req was ever dispatched'); - assert.equal(rm.outstandingRequests(), 0, 'no pending requests'); - assert.equal(requestsResolved, 1, 'one request got resolved'); - assert.equal( - requestsRejected, 4, 'four were cancelled and threw errors'); - done(); - }; - rm.request('0').then(success, failure).then(finishTheTest); - rm.request('1').then(success, failure); - rm.request('2').then(success, failure); - rm.request('3').then(success, failure); - rm.request('4').then(success, failure); - assert.equal(rm.activeRequests(), 1, 'one req is active'); - rm.waitForDispatch(1).then(() => { - assert.equal(rm.activeRequests(), 1, 'one req is active'); - assert.equal(rm.requestsDispatched, 1, 'one req was dispatched'); - assert.equal(rm.outstandingRequests(), 5, 'five reqs outstanding'); - rm.clearQueue(); - rm.resolveFakeRequest(); - // resolving the first request triggers finishTheTest - }); - }); - }); - }); -} diff --git a/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts b/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts new file mode 100644 index 00000000000..3800e6e4021 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts @@ -0,0 
+1,294 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {RequestManager, RequestNetworkError} from '../requestManager'; + +interface MockRequest { + resolve: Function; + reject: Function; + id: number; + url: string; +} + +class MockedRequestManager extends RequestManager { + private resolvers: Function[]; + private rejectors: Function[]; + public requestsDispatched: number; + constructor(maxRequests = 10, maxRetries = 3) { + super(maxRequests, maxRetries); + this.resolvers = []; + this.rejectors = []; + this.requestsDispatched = 0; + } + protected _promiseFromUrl(url) { + return new Promise((resolve, reject) => { + const mockJSON = { + ok: true, + json() { + return url; + }, + url, + status: 200, + }; + const mockFailedRequest: any = { + ok: false, + url, + status: 502, + }; + const mockFailure = new RequestNetworkError(mockFailedRequest, url); + this.resolvers.push(() => { + resolve(mockJSON); + }); + this.rejectors.push(() => { + reject(mockFailure); + }); + this.requestsDispatched++; + }); + } + public resolveFakeRequest() { + this.resolvers.pop()(); + } + public rejectFakeRequest() { + this.rejectors.pop()(); + } + public dispatchAndResolve() { + // Wait for at least one request to be dispatched, then resolve it. + this.waitForDispatch(1).then(() => this.resolveFakeRequest()); + } + public waitForDispatch(num) { + return waitForCondition(() => { + return this.requestsDispatched >= num; + }); + } +} + +/** Create a promise that returns when *check* returns true. + * May cause a test timeout if check never becomes true. 
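A minimal sketch of driving the `MockedRequestManager` defined above, assuming only the members shown in this diff (`request`, `waitForDispatch`, `resolveFakeRequest`, `requestsDispatched`):

```ts
// With maxRequests = 1, the second request stays queued until the first
// fake request is resolved.
const rm = new MockedRequestManager(/* maxRequests */ 1, /* maxRetries */ 0);
const first = rm.request('a');
const second = rm.request('b');  // queued; only one request may be active

rm.waitForDispatch(1)
    .then(() => rm.resolveFakeRequest())   // completes 'a'; 'b' dispatches
    .then(() => rm.waitForDispatch(2))
    .then(() => rm.resolveFakeRequest());  // completes 'b'

Promise.all([first, second]).then(() => {
  console.log(rm.requestsDispatched);  // 2
});
```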
+ */ + +function waitForCondition(check: () => boolean): Promise { + return new Promise((resolve, reject) => { + const go = () => { + if (check()) { + resolve(); + } + setTimeout(go, 2); + }; + go(); + }); +} + +describe('backend', () => { + describe('request manager', () => { + it('request loads JSON properly', (done) => { + const rm = new RequestManager(); + const promise = rm.request('data/example.json'); + promise.then( + (response) => { + chai.assert.deepEqual(response, {foo: 3, bar: 'zoidberg'}); + done(); + }, + (reject) => { + throw new Error(reject); + }); + }); + + it('rejects on bad url', (done) => { + const rm = new RequestManager(5, 0); + const badUrl = '_bad_url_which_doesnt_exist.json'; + const promise = rm.request(badUrl); + promise.then( + (success) => { + done(new Error('the promise should have rejected')); + }, + (reject: RequestNetworkError) => { + chai.assert.include(reject.message, '404'); + chai.assert.include(reject.message, badUrl); + chai.assert.equal(reject.req.status, 404); + done(); + }); + }); + + it('can retry if requests fail', (done) => { + const rm = new MockedRequestManager(3, 5); + const r = rm.request('foo'); + rm.waitForDispatch(1) + .then(() => { + rm.rejectFakeRequest(); + return rm.waitForDispatch(2); + }) + .then(() => rm.resolveFakeRequest()); + r.then((success) => done()); + }); + + it('retries at most maxRetries times', (done) => { + const MAX_RETRIES = 2; + const rm = new MockedRequestManager(3, MAX_RETRIES); + const r = rm.request('foo'); + rm.waitForDispatch(1) + .then(() => { + rm.rejectFakeRequest(); + return rm.waitForDispatch(2); + }) + .then(() => { + rm.rejectFakeRequest(); + return rm.waitForDispatch(3); + }) + .then(() => { + rm.rejectFakeRequest(); + }); + + r.then( + (success) => done(new Error('The request should have failed')), + (failure) => done()); + }); + + it('requestManager only sends maxRequests requests at a time', (done) => { + const rm = new MockedRequestManager(3); + const r0 = rm.request('1'); + const r1 = rm.request('2'); + const r2 = rm.request('3'); + const r3 = rm.request('4'); + chai.assert.equal(rm.activeRequests(), 3, 'three requests are active'); + chai.assert.equal( + rm.outstandingRequests(), 4, 'four requests are pending'); + rm.waitForDispatch(3) + .then(() => { + chai.assert.equal( + rm.activeRequests(), 3, 'three requests are still active (1)'); + chai.assert.equal( + rm.requestsDispatched, 3, 'three requests were dispatched'); + rm.resolveFakeRequest(); + return rm.waitForDispatch(4); + }) + .then(() => { + chai.assert.equal( + rm.activeRequests(), 3, 'three requests are still active (2)'); + chai.assert.equal( + rm.requestsDispatched, 4, 'four requests were dispatched'); + chai.assert.equal( + rm.outstandingRequests(), 3, 'three requests are pending'); + rm.resolveFakeRequest(); + rm.resolveFakeRequest(); + rm.resolveFakeRequest(); + return r3; + }) + .then(() => { + chai.assert.equal(rm.activeRequests(), 0, 'all requests finished'); + chai.assert.equal( + rm.outstandingRequests(), 0, 'no requests pending'); + done(); + }); + }); + + it('queue continues after failures', (done) => { + const rm = new MockedRequestManager(1, 0); + const r0 = rm.request('1'); + const r1 = rm.request('2'); + rm.waitForDispatch(1).then(() => { + rm.rejectFakeRequest(); + }); + + r0.then( + (success) => done(new Error('r0 should have failed')), + (failure) => 'unused_argument') + .then(() => rm.resolveFakeRequest()); + + // When the first request rejects, it should decrement nActiveRequests + // and then launch remaining 
requests in queue (i.e. this one) + r1.then((success) => done(), (failure) => done(new Error(failure))); + }); + + it('queue is LIFO', (done) => { + /* This test is a bit tricky. + * We want to verify that the RequestManager queue has LIFO semantics. + * So we construct three requests off the bat: A, B, C. + * So LIFO semantics ensure these will resolve in order A, C, B. + * (Because the A request launches immediately when we create it, it's + * not in queue) + * Then after resolving A, C moves out of queue, and we create X. + * So expected final order is A, C, X, B. + * We verify this with an external var that counts how many requests were + * resolved. + */ + const rm = new MockedRequestManager(1); + let nResolved = 0; + function assertResolutionOrder(expectedSpotInSequence) { + return () => { + nResolved++; + chai.assert.equal(expectedSpotInSequence, nResolved); + }; + } + + function launchThirdRequest() { + rm.request('started late but goes third') + .then(assertResolutionOrder(3)) + .then(() => rm.dispatchAndResolve()); + } + + rm.request('first') + .then( + assertResolutionOrder(1)) // Assert that this one resolved first + .then(launchThirdRequest) + .then(() => rm.dispatchAndResolve()); // then trigger the next one + + rm.request('this one goes fourth') // created second, will go last + .then(assertResolutionOrder( + 4)) // assert it was the fourth to get resolved + .then(done); // finish the test + + rm.request('second') + .then(assertResolutionOrder(2)) + .then(() => rm.dispatchAndResolve()); + + rm.dispatchAndResolve(); + }); + + it('requestManager can clear queue', (done) => { + const rm = new MockedRequestManager(1); + let requestsResolved = 0; + let requestsRejected = 0; + const success = () => requestsResolved++; + const failure = (err) => { + chai.assert.equal(err.name, 'RequestCancellationError'); + requestsRejected++; + }; + const finishTheTest = () => { + chai.assert.equal(rm.activeRequests(), 0, 'no requests still active'); + chai.assert.equal( + rm.requestsDispatched, 1, 'only one req was ever dispatched'); + chai.assert.equal(rm.outstandingRequests(), 0, 'no pending requests'); + chai.assert.equal(requestsResolved, 1, 'one request got resolved'); + chai.assert.equal( + requestsRejected, 4, 'four were cancelled and threw errors'); + done(); + }; + rm.request('0').then(success, failure).then(finishTheTest); + rm.request('1').then(success, failure); + rm.request('2').then(success, failure); + rm.request('3').then(success, failure); + rm.request('4').then(success, failure); + chai.assert.equal(rm.activeRequests(), 1, 'one req is active'); + rm.waitForDispatch(1).then(() => { + chai.assert.equal(rm.activeRequests(), 1, 'one req is active'); + chai.assert.equal(rm.requestsDispatched, 1, 'one req was dispatched'); + chai.assert.equal(rm.outstandingRequests(), 5, 'five reqs outstanding'); + rm.clearQueue(); + rm.resolveFakeRequest(); + // resolving the first request triggers finishTheTest + }); + }); + }); +}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/tests.html b/tensorflow/tensorboard/components/tf_backend/test/tests.html new file mode 100644 index 00000000000..58cb89a30b6 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_backend/test/tests.html @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_backend/tf-backend.html b/tensorflow/tensorboard/components/tf_backend/tf-backend.html index 0e07c7fdb1e..c2a44b3b63f 100644 --- a/tensorflow/tensorboard/components/tf_backend/tf-backend.html +++ 
b/tensorflow/tensorboard/components/tf_backend/tf-backend.html
@@ -23,5 +23,6 @@ limitations under the License.
+
diff --git a/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts b/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts
index 7c59eafb448..62519dac5ca 100644
--- a/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts
+++ b/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts
@@ -12,31 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-module TF.Backend {
-  export var BAD_CHARACTERS = '#%&{}\\/<>*? $!\'":@+`|=() ';
-  /** Cleanup a url so that it can be loaded from a filesystem. */
-  export function demoify(s) {
-    // for consistency with python's urllib.urlencode
-    s = s.replace(new RegExp('%20', 'g'), '+');
-    for (var i = 0; i < BAD_CHARACTERS.length; i++) {
-      var c = BAD_CHARACTERS[i];
-      s = s.replace(new RegExp('\\' + c, 'g'), '_');
-    }
-    return s;
-  }
-
-  export function queryEncoder(params?: any): string {
-    // It's important that the keys be sorted, so we always grab the right file
-    // if we are talking to the backend generated by serialze_tensorboard.py
-    if (params == null) {
-      return '';
-    }
-    var components = _.keys(params)
-                         .sort()
-                         .filter((k) => params[k] !== undefined)
-                         .map((k) => k + '=' + encodeURIComponent(params[k]));
-    var result = components.length ? '?' + components.join('&') : '';
-    // Replace parens for consistency with urllib.urlencode
-    return result.replace(/\(/g, '%28').replace(/\)/g, '%29');
+export const BAD_CHARACTERS = '#%&{}\\/<>*? $!\'":@+`|=() ';
+/** Clean up a URL so that it can be loaded from a filesystem. */
+export function demoify(s) {
+  // for consistency with Python's urllib.urlencode
+  s = s.replace(new RegExp('%20', 'g'), '+');
+  for (let i = 0; i < BAD_CHARACTERS.length; i++) {
+    const c = BAD_CHARACTERS[i];
+    s = s.replace(new RegExp('\\' + c, 'g'), '_');
 }
+  return s;
+}
+
+export function queryEncoder(params?: any): string {
+  // It's important that the keys be sorted, so we always grab the right file
+  // if we are talking to the backend generated by serialize_tensorboard.py
+  if (params == null) {
+    return '';
+  }
+  const components = _.keys(params)
+                         .sort()
+                         .filter((k) => params[k] !== undefined)
+                         .map((k) => k + '=' + encodeURIComponent(params[k]));
+  const result = components.length ? '?'
+ components.join('&') : ''; + // Replace parens for consistency with urllib.urlencode + return result.replace(/\(/g, '%28').replace(/\)/g, '%29'); } diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD new file mode 100644 index 00000000000..730ab37d6f7 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/BUILD @@ -0,0 +1,39 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_color_scale", + srcs = [ + "colorScale.ts", + "palettes.ts", + "tf-color-scale.html", + ], + path = "/tf-color-scale", + deps = [ + "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:polymer", + ], +) + +ts_web_library( + name = "demo", + srcs = ["index.html"], + path = "/tf-color-scale", + deps = [ + ":tf_color_scale", + "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_button", + "@org_polymer_paper_styles", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts index c05d9765335..e20a65cdd84 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts +++ b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts @@ -15,49 +15,75 @@ limitations under the License. // Example usage: // runs = ["train", "test", "test1", "test2"] -// ccs = new TF.ColorScale(); +// ccs = new ColorScale(); // ccs.domain(runs); // ccs.getColor("train"); // ccs.getColor("test1"); -module TF { - export class ColorScale { - private palette: string[]; - private identifiers = d3.map(); +import {palettes} from './palettes'; - /** - * Creates a color scale with optional custom palette. - * @param {string[]} [palette=TF.palettes.googleColorBlind] - The color - * palette you want as an Array of hex strings. - */ - constructor(palette: string[] = TF.palettes.googleColorBlindAssist) { - this.palette = palette; - } +export class ColorScale { + private identifiers = d3.map(); - /** - * Set the domain of strings. - * @param {string[]} strings - An array of possible strings to use as the - * domain for your scale. - */ - public domain(strings: string[]): this { - this.identifiers = d3.map(); - strings.forEach((s, i) => { - this.identifiers.set(s, this.palette[i % this.palette.length]); - }); - return this; - } + /** + * Creates a color scale with optional custom palette. + * @param {Array} [palette=palettes.googleColorBlind] - The color + * palette you want as an Array of hex strings. + */ + constructor( + private readonly palette: string[] = palettes.googleColorBlindAssist) {} - /** - * Use the color scale to transform an element in the domain into a color. - * @param {string} The input string to map to a color. - * @return {string} The color corresponding to that input string. - * @throws Will error if input string is not in the scale's domain. - */ - public scale(s: string): string { - if (!this.identifiers.has(s)) { - throw new Error('String was not in the domain.'); - } - return this.identifiers.get(s) as string; + /** + * Set the domain of strings. 
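Since `demoify` and `queryEncoder` (just above) cooperate in demo mode, a worked example may help. The `tag`/`run` values are hypothetical, and the expected strings in the comments are hand-computed from the two implementations:

```ts
import {demoify, queryEncoder} from './urlPathHelpers';

const query = queryEncoder({tag: 'loss (raw)', run: 'train'});
// => '?run=train&tag=loss%20%28raw%29'
// Keys are sorted; encodeURIComponent leaves '(' and ')' alone, so the
// parens are replaced explicitly for parity with urllib.urlencode.

demoify(query);
// => '_run_train_tag_loss__28raw_29'
// '%20' first becomes '+', then every BAD_CHARACTERS entry becomes '_',
// yielding a name that can be served from a plain filesystem in demo mode.
```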
+ * @param {Array} strings - An array of possible strings to use as the + * domain for your scale. + */ + public domain(strings: string[]): this { + this.identifiers = d3.map(); + + // TODO(wchargin): Remove this call to `sort` once we have only a + // singleton ColorScale, linked directly to the RunsStore, which + // will always give sorted output. + strings = strings.slice(); + strings.sort(); + + strings.forEach((s, i) => { + this.identifiers.set(s, this.palette[i % this.palette.length]); + }); + return this; + } + + /** + * Use the color scale to transform an element in the domain into a color. + * @param {string} The input string to map to a color. + * @return {string} The color corresponding to that input string. + * @throws Will error if input string is not in the scale's domain. + */ + public scale(s: string): string { + if (!this.identifiers.has(s)) { + throw new Error('String was not in the domain.'); } + return this.identifiers.get(s) as string; } } + +Polymer({ + is: 'tf-color-scale', + properties: { + runs: { + type: Array, + }, + outColorScale: { + type: Object, + readOnly: true, + notify: true, + value() { + return new ColorScale(); + }, + }, + }, + observers: ['updateColorScale(runs.*)'], + updateColorScale(runsChange) { + this.outColorScale.domain(this.runs); + }, +}); diff --git a/tensorflow/tensorboard/components/tf_color_scale/index.html b/tensorflow/tensorboard/components/tf_color_scale/index.html new file mode 100644 index 00000000000..81dfab098c6 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/index.html @@ -0,0 +1,94 @@ + + + + + +tf-color-scale demo + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_color_scale/palettes.ts b/tensorflow/tensorboard/components/tf_color_scale/palettes.ts index c53ed599ae9..ce42a115458 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/palettes.ts +++ b/tensorflow/tensorboard/components/tf_color_scale/palettes.ts @@ -13,68 +13,64 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF { - export const palettes = { - googleStandard: [ - '#db4437', // google red 500 - '#ff7043', // deep orange 400 - '#f4b400', // google yellow 500 - '#0f9d58', // google green 500 - '#00796b', // teal 700 - '#00acc1', // cyan 600 - '#4285f4', // google blue 500 - '#5c6bc0', // indigo 400 - '#ab47bc' // purple 400 - ], - googleCool: [ - '#9e9d24', // lime 800 - '#0f9d58', // google green 500 - '#00796b', // teal 700 - '#00acc1', // cyan 600 - '#4285f4', // google blue 500 - '#5c6bc0', // indigo 400 - '#607d8b' // blue gray 500 - ], - googleWarm: [ - '#795548', // brown 500 - '#ab47bc', // purple 400 - '#f06292', // pink 300 - '#c2185b', // pink 700 - '#db4437', // google red 500 - '#ff7043', // deep orange 400 - '#f4b400' // google yellow 700 - ], - googleColorBlindAssist: [ - '#ff7043', // orange - '#00ACC1', // dark cyan - '#AB47BC', // bright purple - '#2A56C6', // dark blue - '#0b8043', // green - '#F7CB4D', // yellow - '#c0ca33', // lime - '#5e35b1', // purple - '#A52714', // red - ], - // These palettes try to be better for color differentiation. 
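A short usage sketch of the reworked `ColorScale`; the run names are hypothetical, and the stability claim follows from the new `sort` call in `domain` flagged by the TODO above:

```ts
import {ColorScale} from './colorScale';
import {palettes} from './palettes';

const ccs = new ColorScale(palettes.googleStandard);  // palette is optional
ccs.domain(['train', 'eval', 'test']);
const c1 = ccs.scale('train');

// domain() now sorts its input, so the same set of runs yields the same
// color assignment regardless of the order in which they arrive.
ccs.domain(['test', 'train', 'eval']);
const c2 = ccs.scale('train');
console.log(c1 === c2);  // true

// ccs.scale('nope');    // would throw: 'String was not in the domain.'
```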
- // https://personal.sron.nl/~pault/ - colorBlindAssist1: - ['#4477aa', '#44aaaa', '#aaaa44', '#aa7744', '#aa4455', '#aa4488'], - colorBlindAssist2: [ - '#88ccee', '#44aa99', '#117733', '#999933', '#ddcc77', '#cc6677', - '#882255', '#aa4499' - ], - colorBlindAssist3: [ - '#332288', '#6699cc', '#88ccee', '#44aa99', '#117733', '#999933', - '#ddcc77', '#cc6677', '#aa4466', '#882255', '#661100', '#aa4499' - ], - // based on this palette: http://mkweb.bcgsc.ca/biovis2012/ - colorBlindAssist4: [ - '#FF6DB6', '#920000', '#924900', '#DBD100', '#24FF24', '#006DDB', - '#490092' - ], - mldash: [ - '#E47EAD', '#F4640D', '#FAA300', '#F5E636', '#00A077', '#0077B8', - '#00B7ED' - ] - }; -} +export const palettes = { + googleStandard: [ + '#db4437', // google red 500 + '#ff7043', // deep orange 400 + '#f4b400', // google yellow 500 + '#0f9d58', // google green 500 + '#00796b', // teal 700 + '#00acc1', // cyan 600 + '#4285f4', // google blue 500 + '#5c6bc0', // indigo 400 + '#ab47bc' // purple 400 + ], + googleCool: [ + '#9e9d24', // lime 800 + '#0f9d58', // google green 500 + '#00796b', // teal 700 + '#00acc1', // cyan 600 + '#4285f4', // google blue 500 + '#5c6bc0', // indigo 400 + '#607d8b' // blue gray 500 + ], + googleWarm: [ + '#795548', // brown 500 + '#ab47bc', // purple 400 + '#f06292', // pink 300 + '#c2185b', // pink 700 + '#db4437', // google red 500 + '#ff7043', // deep orange 400 + '#f4b400' // google yellow 700 + ], + googleColorBlindAssist: [ + '#ff7043', // orange + '#00ACC1', // dark cyan + '#AB47BC', // bright purple + '#2A56C6', // dark blue + '#0b8043', // green + '#F7CB4D', // yellow + '#c0ca33', // lime + '#5e35b1', // purple + '#A52714', // red + ], + // These palettes try to be better for color differentiation. + // https://personal.sron.nl/~pault/ + colorBlindAssist1: + ['#4477aa', '#44aaaa', '#aaaa44', '#aa7744', '#aa4455', '#aa4488'], + colorBlindAssist2: [ + '#88ccee', '#44aa99', '#117733', '#999933', '#ddcc77', '#cc6677', '#882255', + '#aa4499' + ], + colorBlindAssist3: [ + '#332288', '#6699cc', '#88ccee', '#44aa99', '#117733', '#999933', '#ddcc77', + '#cc6677', '#aa4466', '#882255', '#661100', '#aa4499' + ], + // based on this palette: http://mkweb.bcgsc.ca/biovis2012/ + colorBlindAssist4: [ + '#FF6DB6', '#920000', '#924900', '#DBD100', '#24FF24', '#006DDB', '#490092' + ], + mldash: [ + '#E47EAD', '#F4640D', '#FAA300', '#F5E636', '#00A077', '#0077B8', '#00B7ED' + ] +}; diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD new file mode 100644 index 00000000000..331783f3c76 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD @@ -0,0 +1,30 @@ +package( + default_testonly = True, + default_visibility = ["//tensorflow/tensorboard:internal"], +) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "test", + srcs = [ + "colorScaleTests.ts", + "tests.html", + ], + path = "/tf-color-scale/test", + deps = [ + "//tensorflow/tensorboard/components/tf_color_scale", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + testonly = 0, + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts 
b/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts index 700a01848b6..78824a772c3 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts +++ b/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts @@ -13,34 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module TF { - let assert = chai.assert; +let assert = chai.assert; - describe('ColorScale', function() { - let ccs: ColorScale; +import {ColorScale} from '../colorScale' - beforeEach(function() { ccs = new ColorScale(); }); +describe('ColorScale', function() { + let ccs: ColorScale; - it('Returns consistent colors', function() { - ccs.domain(['train', 'eval', 'test']); - let trainColor = ccs.scale('train'); - let trainColor2 = ccs.scale('train'); - assert.equal(trainColor, trainColor2); - }); - - it('Returns consistent colors after new domain', function() { - ccs.domain(['train', 'eval']); - let trainColor = ccs.scale('train'); - ccs.domain(['train', 'eval', 'test']); - let trainColor2 = ccs.scale('train'); - assert.equal(trainColor, trainColor2); - }); - - it('Throws an error if string is not in the domain', function() { - ccs.domain(['red', 'yellow', 'green']); - assert.throws(function() { - ccs.scale('not in domain'); - }, 'String was not in the domain.'); - }); + beforeEach(function() { + ccs = new ColorScale(); }); -} + + it('Returns consistent colors', function() { + ccs.domain(['train', 'eval', 'test']); + let trainColor = ccs.scale('train'); + let trainColor2 = ccs.scale('train'); + assert.equal(trainColor, trainColor2); + }); + + it('Returns consistent colors after new domain', function() { + ccs.domain(['train', 'eval']); + let trainColor = ccs.scale('train'); + ccs.domain(['train', 'eval', 'test']); + let trainColor2 = ccs.scale('train'); + assert.equal(trainColor, trainColor2); + }); + + it('Throws an error if string is not in the domain', function() { + ccs.domain(['red', 'yellow', 'green']); + assert.throws(function() { + ccs.scale('not in domain'); + }, 'String was not in the domain.'); + }); +}); diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/index.html b/tensorflow/tensorboard/components/tf_color_scale/test/index.html deleted file mode 100644 index 9a2a174349c..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/test/index.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html new file mode 100644 index 00000000000..59c802d02bf --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html @@ -0,0 +1,24 @@ + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html index 743996f6241..a325f0a04cd 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html @@ -16,6 +16,7 @@ limitations under the License. 
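The rewritten tests above import `ColorScale` directly instead of reaching through the `TF` namespace. One behavior they do not pin down is palette cycling via `palette[i % palette.length]`; a hedged sketch of an additional test in the same wct/chai style (not part of this change, and relying on the same ambient `chai` and `d3` globals the existing test uses):

```ts
import {ColorScale} from '../colorScale';

describe('ColorScale palette cycling', function() {
  it('reuses colors once the domain outgrows the palette', function() {
    const ccs = new ColorScale();
    // Zero-pad so the sorted domain order matches insertion order.
    const runs: string[] = [];
    for (let i = 0; i < 20; i++) {
      runs.push('run' + ('0' + i).slice(-2));
    }
    ccs.domain(runs);
    // With 20 runs and assignment by index modulo palette length, at
    // least two runs must share a color for any palette defined above.
    const colors = runs.map((r) => ccs.scale(r));
    chai.assert.isBelow(d3.set(colors).size(), runs.length);
  });
});
```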
--> + - + - - + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts new file mode 100644 index 00000000000..0eaf852ff13 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts @@ -0,0 +1,189 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {compareTagNames} from '../vz-sorting/sorting'; + +/** + * This module contains methods that allow sorting tags into 'categories'. + * A category contains a name and a list of tags. + * The sorting strategy is defined by a 'CustomCategorization', which contains + * 'categoryDefinitions' which are regex rules used to construct a category. + * E.g. the regex rule 'xent' will create a category called 'xent' that + * contains values whose tags match the regex. + * + * After custom categories are evaluated, the tags are sorted by a hardcoded + * fallback categorizer, which may, for example, group tags into categories + * based on their top namespace. + */ + +export interface Category { + // Categories that data is sorted into + name: string; + tags: string[]; +} + +export interface CustomCategorization { + // Defines a categorization strategy + categoryDefinitions: string[]; + fallbackCategorizer: string; + /* {'TopLevelNamespaceCategorizer', + 'LegacyUnderscoreCategorizer'} */ +} + +export interface Categorizer { + // Function that generates categories + (tags: string[]): Category[]; +} + +/* Canonical TensorFlow ops are namespaced using forward slashes. + * This fallback categorizer categorizes by the top-level namespace. + */ +export var topLevelNamespaceCategorizer: Categorizer = splitCategorizer(/\//); + +export function fallbackCategorizer(s: string): Categorizer { + switch (s) { + case 'TopLevelNamespaceCategorizer': + return topLevelNamespaceCategorizer; + default: + throw new Error('Unrecognized categorization strategy: ' + s); + } +} + +/* An 'extractor' is a function that takes a tag name, and 'extracts' a + * category name. + * This function takes an extractor, and produces a categorizer. + * Currently, it is just used for the fallbackCategorizer, but we may want to + * refactor the general categorization logic to use the concept of extractors. + */ +function extractorToCategorizer(extractor: (s: string) => string): Categorizer { + return (tags: string[]): Category[] => { + if (tags.length === 0) { + return []; + } + + // Maps between top-level name and category. We use the mapping to avoid + // duplicating categories per run. 
+ const categoryMapping: {[key: string]: Category} = {}; + + tags.forEach((t: string) => { + const topLevel = extractor(t); + if (!categoryMapping[topLevel]) { + const newCategory = { + name: topLevel, + tags: [], + }; + categoryMapping[topLevel] = newCategory; + } + + categoryMapping[topLevel].tags.push(t); + }); + + // Sort categories into alphabetical order. + const categories = + _.map(_.keys(categoryMapping).sort(), key => categoryMapping[key]); + _.forEach(categories, (category) => { + // Sort the tags within each category. + category.tags.sort(compareTagNames); + }); + return categories; + }; +} + +function splitCategorizer(r: RegExp): Categorizer { + let extractor = (t: string) => { + return t.split(r)[0]; + }; + return extractorToCategorizer(extractor); +} + +export interface CategoryDefinition { + name: string; + matches: (t: string) => boolean; +} + +export function defineCategory(ruledef: string): CategoryDefinition { + let r = new RegExp(ruledef); + let f = function(tag: string): boolean { + return r.test(tag); + }; + return {name: ruledef, matches: f}; +} + +export function _categorizer( + rules: CategoryDefinition[], fallback: Categorizer) { + return function(tags: string[]): Category[] { + let remaining: d3.Set = d3.set(tags); + let userSpecified = rules.map((def: CategoryDefinition) => { + let tags: string[] = []; + remaining.each((t: string) => { + if (def.matches(t)) { + tags.push(t); + } + }); + let cat = {name: def.name, tags: tags.sort(compareTagNames)}; + return cat; + }); + let defaultCategories = fallback(remaining.values()); + return userSpecified.concat(defaultCategories); + }; +} + +export function categorizer(s: CustomCategorization): Categorizer { + let rules = s.categoryDefinitions.map(defineCategory); + let fallback = fallbackCategorizer(s.fallbackCategorizer); + return _categorizer(rules, fallback); +}; + +Polymer({ + is: 'tf-categorizer', + properties: { + regexes: {type: Array}, + tags: {type: Array}, + categoriesAreExclusive: {type: Boolean, value: true}, + fallbackCategorizer: { + type: String, + value: 'TopLevelNamespaceCategorizer', + }, + categorizer: { + type: Object, + computed: + 'computeCategorization(regexes.*, categoriesAreExclusive, fallbackCategorizer)', + }, + categories: { + type: Array, + value: function() { + return []; + }, + notify: true, + readOnly: true + }, + }, + observers: ['recategorize(tags.*, categorizer)'], + computeCategorization: function( + regexes, categoriesAreExclusive, fallbackCategorizer) { + var categorizationStrategy = { + categoryDefinitions: regexes.base, + categoriesAreExclusive: categoriesAreExclusive, + fallbackCategorizer: fallbackCategorizer, + }; + return categorizer(categorizationStrategy); + }, + recategorize: function() { + this.debounce('tf-categorizer-recategorize', function() { + var categories = this.categorizer(this.tags); + this._setCategories(categories); + }) + }, +}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html index e2530d59716..a39fb9462ba 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html @@ -16,8 +16,6 @@ limitations under the License. 
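tf-categorizer.ts above builds a categorizer from user regex rules plus a fallback that groups tags by top-level namespace. Note that in this revision matched tags are not removed from `remaining`, so they also appear in the fallback categories. A sketch of the expected output, with the import path assumed and `d3`/lodash ambient as in the file itself:

```ts
import {categorizer, Category} from './tf-categorizer';  // path assumed

const categorize = categorizer({
  categoryDefinitions: ['xent'],  // regex rule -> category named 'xent'
  fallbackCategorizer: 'TopLevelNamespaceCategorizer',
});

const categories: Category[] =
    categorize(['train/xent', 'eval/xent', 'train/accuracy']);

// User-specified categories come first, then the fallback categories in
// alphabetical order. Because `remaining` is never pruned, the xent tags
// show up both under 'xent' and under their namespaces:
//   {name: 'xent',  tags: ['eval/xent', 'train/xent']}
//   {name: 'eval',  tags: ['eval/xent']}
//   {name: 'train', tags: ['train/accuracy', 'train/xent']}
```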
--> - - - + diff --git a/tensorflow/tensorboard/components/tf_globals/BUILD b/tensorflow/tensorboard/components/tf_globals/BUILD new file mode 100644 index 00000000000..c5b0cfbaa55 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_globals/BUILD @@ -0,0 +1,27 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_globals", + srcs = [ + "globals.ts", + "tf-globals.html", + ], + path = "/tf-globals", +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_globals"], + destdir = "tf-globals", +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_globals/globals.ts b/tensorflow/tensorboard/components/tf_globals/globals.ts index 33feb26d238..fb6bb83b97f 100644 --- a/tensorflow/tensorboard/components/tf_globals/globals.ts +++ b/tensorflow/tensorboard/components/tf_globals/globals.ts @@ -13,20 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* tslint:disable:no-namespace */ -module TF.Globals { +// The names of TensorBoard tabs. +export const TABS = [ + 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', + 'embeddings', 'text' +]; - // The names of TensorBoard tabs. - export var TABS = [ - 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', - 'embeddings' - ]; +// If true, TensorBoard stores its hash in the URI state. +// If false, tab switching in TensorBoard will not update location hash, +// because hash updates interfere with wct_tests. +let _useHash = false; - // If true, TensorBoard stores its hash in the URI state. - // If false, tab switching in TensorBoard will not update location hash, - // because hash updates interfere with wct_tests. - export var USE_HASH = false; - - // If USE_HASH is false, FAKE_HASH holds the hash contents. - export var FAKE_HASH = ''; +export function setUseHash(shouldUseHash: boolean): void { + _useHash = shouldUseHash; +} + +export function useHash(): boolean { + return _useHash; +} + +let _fakeHash = ''; + +export function setFakeHash(h: string) { + _fakeHash = h; +} + +export function getFakeHash() { + return _fakeHash; } diff --git a/tensorflow/tensorboard/components/tf_globals/tf-globals.html b/tensorflow/tensorboard/components/tf_globals/tf-globals.html index 952979d0be9..efb8e92e080 100644 --- a/tensorflow/tensorboard/components/tf_globals/tf-globals.html +++ b/tensorflow/tensorboard/components/tf_globals/tf-globals.html @@ -15,7 +15,5 @@ See the License for the specific language governing permissions and limitations under the License. 
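globals.ts above replaces the mutable `TF.Globals.USE_HASH` and `FAKE_HASH` variables with module-private state behind accessor functions (and adds 'text' to `TABS`). A sketch of the call-site migration, with the import path assumed:

```ts
import {getFakeHash, setFakeHash, setUseHash, useHash} from './globals';

// Before: TF.Globals.USE_HASH = false;  TF.Globals.FAKE_HASH = 'scalars';
// After: the flag and fake hash live behind setters, so state changes are
// explicit function calls rather than ad-hoc global mutation.
setUseHash(false);  // wct tests: keep tab state out of location.hash
if (!useHash()) {
  setFakeHash('scalars');
  console.log(getFakeHash());  // 'scalars'
}
```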
--> - - - + diff --git a/tensorflow/tensorboard/components/tf_graph/BUILD b/tensorflow/tensorboard/components/tf_graph/BUILD new file mode 100644 index 00000000000..4c0894f1925 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph/BUILD @@ -0,0 +1,56 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_graph", + srcs = [ + "tf-graph.html", + "tf-graph-minimap.html", + "tf-graph-scene.html", + ], + path = "/tf-graph", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_graph_common", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "@org_polymer_iron_flex_layout", + "@org_polymer_iron_icons", + "@org_polymer_paper_button", + "@org_polymer_paper_dropdown_menu", + "@org_polymer_paper_input", + "@org_polymer_paper_menu", + "@org_polymer_paper_radio_group", + "@org_polymer_paper_toggle_button", + "@org_polymer_paper_tooltip", + ], +) + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":tf_graph"], + destdir = "tf-graph", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", + "//tensorflow/tensorboard/components/tf_graph_common:legacy", + "//third_party/javascript/polymer/v1/iron-flex-layout:lib", + "//third_party/javascript/polymer/v1/iron-icons:lib", + "//third_party/javascript/polymer/v1/paper-button:lib", + "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", + "//third_party/javascript/polymer/v1/paper-input:lib", + "//third_party/javascript/polymer/v1/paper-menu:lib", + "//third_party/javascript/polymer/v1/paper-radio-group:lib", + "//third_party/javascript/polymer/v1/paper-toggle-button:lib", + "//third_party/javascript/polymer/v1/paper-tooltip:lib", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/BUILD b/tensorflow/tensorboard/components/tf_graph/demo/BUILD new file mode 100644 index 00000000000..02f3bf64bbc --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph/demo/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph/demo +ts_web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph", + "//tensorflow/tensorboard/components/tf_graph_common", + "//tensorflow/tensorboard/components/tf_graph_loader", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt new file mode 100644 index 00000000000..30b20645346 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + 
key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } 
+ } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + 
name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + 
tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + 
} + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + 
value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: 
"dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: 
DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + 
value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: 
"mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + 
} +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 
10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { 
+ dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: 
"/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: 
"gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: 
"gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph/demo/demo_datasets.json b/tensorflow/tensorboard/components/tf_graph/demo/demo_datasets.json deleted file mode 100644 index f5ca9aada79..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/demo/demo_datasets.json +++ /dev/null @@ -1,123 +0,0 @@ -[ - { - "name": "Mnist Eval", - 
"path": "mnist_eval.pbtxt" - }, - { - "name": "Mnist with summaries (+stats)", - "path": "mnist_with_summaries.pbtxt", - "runMetadata": [ - { - "tag": "step100", - "path": "mnist_with_summaries_step100.pbtxt" - }, - { - "tag": "step1000", - "path": "mnist_with_summaries_step1000.pbtxt" - } - ] - }, - { - "name": "Mnist Train (with shapes)", - "path": "mnist_train_shapes.pbtxt" - }, - { - "name": "Inception Train (huge)", - "path": "inception_train.pbtxt" - }, - { - "name": "Inception Train Eval", - "path": "inception_train_eval.pbtxt" - }, - { - "name": "Inception Test", - "path": "inception_test_eval.pbtxt" - }, - { - "name": "PTB Word LSTM Train", - "path": "ptb_word_lstm_train.pbtxt" - }, - { - "name": "PTB Word LSTM Train Eval", - "path": "ptb_word_lstm_train_eval.pbtxt" - }, - { - "name": "PTB Word LSTM Test", - "path": "ptb_word_lstm_test_eval.pbtxt" - }, - { - "name": "Cifar10 Train (+stats)", - "path": "cifar10_train.pbtxt", - "runMetadata": [ - { - "tag": "step0", - "path": "cifar10_train_step0.pbtxt" - }, - { - "tag": "step100", - "path": "cifar10_train_step100.pbtxt" - }, - { - "tag": "step200", - "path": "cifar10_train_step200.pbtxt" - }, - { - "tag": "step300", - "path": "cifar10_train_step300.pbtxt" - } - ] - }, - { - "name": "Cifar10 Multi-GPU Train", - "path": "cifar10_multi_gpu_train.pbtxt" - }, - { - "name": "Cifar10 Eval (+stats)", - "path": "cifar10_eval.pbtxt", - "runMetadata": [ - { - "tag": "step0", - "path": "cifar10_eval_step0.pbtxt" - }, - { - "tag": "step10", - "path": "cifar10_eval_step10.pbtxt" - }, - { - "tag": "step20", - "path": "cifar10_eval_step20.pbtxt" - } - ] - }, - { - "name": "Fatcat LSTM", - "path": "fatcat_lstm.pbtxt" - }, - { - "name": "Legacy Inception Renamed", - "path": "legacy_inception_renamed.pbtxt" - }, - { - "name": "Wolfe (Broken)", - "path": "wolfe1.pbtxt" - }, - { - "name": "Wolfe (Fixed)", - "path": "wolfe2.pbtxt" - }, - { - "id": "alex", - "name": "AlexNet", - "path": "alexnet.pbtxt" - }, - { - "id": "alexprivate", - "name": "AlexNet Private", - "path": "alexnet.pbtxt", - "private": true - }, - { - "name": "TestError404", - "path": "nofile" - } -] diff --git a/tensorflow/tensorboard/components/tf_graph/demo/index.html b/tensorflow/tensorboard/components/tf_graph/demo/index.html index c89490f44d4..52e2f0b9340 100644 --- a/tensorflow/tensorboard/components/tf_graph/demo/index.html +++ b/tensorflow/tensorboard/components/tf_graph/demo/index.html @@ -15,32 +15,78 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> + + + + + +TF Graph Demo + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html b/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html index 1e60cda66ad..66a3034b5b2 100644 --- a/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html +++ b/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html @@ -19,9 +19,10 @@ limitations under the License. - + + - + @@ -315,7 +316,7 @@ limitations under the License.
[[_nodeStatsFormattedBytes]]
- + + + diff --git a/tensorflow/tensorboard/components/tf_imports/dagre.html b/tensorflow/tensorboard/components/tf_imports/dagre.html index 48fe39da793..b90dc58e390 100644 --- a/tensorflow/tensorboard/components/tf_imports/dagre.html +++ b/tensorflow/tensorboard/components/tf_imports/dagre.html @@ -16,9 +16,30 @@ limitations under the License. --> - - - + + + + + diff --git a/tensorflow/tensorboard/components/tf_imports/graphlib.html b/tensorflow/tensorboard/components/tf_imports/graphlib.html index 4e19f7b008f..664b855f17f 100644 --- a/tensorflow/tensorboard/components/tf_imports/graphlib.html +++ b/tensorflow/tensorboard/components/tf_imports/graphlib.html @@ -15,5 +15,6 @@ See the License for the specific language governing permissions and limitations under the License. --> - - + + + diff --git a/tensorflow/tensorboard/components/tf_imports/lodash.html b/tensorflow/tensorboard/components/tf_imports/lodash.html index f92aa808799..65ff6a4b032 100644 --- a/tensorflow/tensorboard/components/tf_imports/lodash.html +++ b/tensorflow/tensorboard/components/tf_imports/lodash.html @@ -15,4 +15,4 @@ See the License for the specific language governing permissions and limitations under the License. --> - + diff --git a/tensorflow/tensorboard/components/tf_imports/numericjs.html b/tensorflow/tensorboard/components/tf_imports/numericjs.html new file mode 100644 index 00000000000..81fa9491688 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_imports/numericjs.html @@ -0,0 +1,43 @@ + + + + + diff --git a/tensorflow/tensorboard/components/tf_imports/plottable.html b/tensorflow/tensorboard/components/tf_imports/plottable.html index 57f9c1d6d3a..77ad544d5a0 100644 --- a/tensorflow/tensorboard/components/tf_imports/plottable.html +++ b/tensorflow/tensorboard/components/tf_imports/plottable.html @@ -15,6 +15,30 @@ See the License for the specific language governing permissions and limitations under the License. --> + + - - + + diff --git a/tensorflow/tensorboard/components/tf_imports/threejs.html b/tensorflow/tensorboard/components/tf_imports/threejs.html new file mode 100644 index 00000000000..7f4233b5713 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_imports/threejs.html @@ -0,0 +1,43 @@ + + + + + + diff --git a/tensorflow/tensorboard/components/tf_imports/weblas.html b/tensorflow/tensorboard/components/tf_imports/weblas.html new file mode 100644 index 00000000000..c07020598fc --- /dev/null +++ b/tensorflow/tensorboard/components/tf_imports/weblas.html @@ -0,0 +1,42 @@ + + + + + diff --git a/tensorflow/tensorboard/components/tf_imports_google/README.md b/tensorflow/tensorboard/components/tf_imports_google/README.md deleted file mode 100644 index 60d9cce777b..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/README.md +++ /dev/null @@ -1,3 +0,0 @@ -This file acts as import routers for third party javascript libraries, -e.g. Plottable and D3 from `g3/third_party`; it exists to facilitate development -inside google. 
diff --git a/tensorflow/tensorboard/components/tf_imports_google/d3.html b/tensorflow/tensorboard/components/tf_imports_google/d3.html deleted file mode 100644 index dbfd11aa87e..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/d3.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tf_imports_google/dagre.html b/tensorflow/tensorboard/components/tf_imports_google/dagre.html deleted file mode 100644 index 5b8b9817410..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/dagre.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tf_imports_google/graphlib.html b/tensorflow/tensorboard/components/tf_imports_google/graphlib.html deleted file mode 100644 index 56b37ebe4bb..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/graphlib.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tf_imports_google/lodash.html b/tensorflow/tensorboard/components/tf_imports_google/lodash.html deleted file mode 100644 index eb8fef28831..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/lodash.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tf_imports_google/plottable.html b/tensorflow/tensorboard/components/tf_imports_google/plottable.html deleted file mode 100644 index 6f9678f9cb2..00000000000 --- a/tensorflow/tensorboard/components/tf_imports_google/plottable.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - diff --git a/tensorflow/tensorboard/components/tf_option_selector/BUILD b/tensorflow/tensorboard/components/tf_option_selector/BUILD new file mode 100644 index 00000000000..3f7eed25cb1 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_option_selector/BUILD @@ -0,0 +1,21 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_option_selector", + srcs = ["tf-option-selector.html"], + path = "/tf-option-selector", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_imports:polymer", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD b/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD new file mode 100644 index 00000000000..5d04618a545 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD @@ -0,0 +1,25 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_profile_dashboard", + srcs = [ + "tf-profile-dashboard.html", + ], + path = "/tf-profile-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_graph_controls", + "@org_polymer", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD new file mode 100644 index 00000000000..3cc20ba352f --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = 
["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-profile-dashboard/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_profile_dashboard", + "//tensorflow/tensorboard/components/tf_trace_viewer:demo", + "@org_polymer", + "@org_polymer_iron_demo_helpers", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir new file mode 100644 index 00000000000..ecaaa8ac758 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir @@ -0,0 +1 @@ +{"logdir": "/some/fake/logdir"} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json new file mode 100644 index 00000000000..bc1a08b535f --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json @@ -0,0 +1,27 @@ +{ + "traceEvents": [ + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", + "name": "counter", "args": {"value": 10}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B", + "name": "A long name that doesnt fit but is exceedingly informative", + "args": {"name_false": false, "value_true": true}}, + {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", + "name": "ProcessWideEvent1", "args": {}} + ], + "stackFrames": { + "1": { + "category": "m1", + "name": "main" + }, + "7": { + "category": "m2", + "name": "frame7", + "parent": "1" + }, + "8": { + "category": "m2", + "name": "frame8", + "parent": "1" + } + } +} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json new file mode 100644 index 00000000000..e1d57394e35 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json @@ -0,0 +1,105 @@ +{ + "traceEvents": [ + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", + "name": "counter", "args": {"value": 10}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B", + "name": "A long name that doesnt fit but is exceedingly informative", + "args": {"name_false": false, "value_true": true}}, + {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", + "name": "ProcessWideEvent1", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 827, "ph": "B", + "name": "Asub with a name that wont fit", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 828, "ph": "E", + "name": "Asub", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 829, "ph": "B", + "name": "Asub", "args": {}}, + {"cat": 
"PREF", "pid": 22630, "tid": 22630, "dur": 15, "ts": 820, "ph": "X", + "name": "Long X type", "args": {}, "sf": 7, "esf": 8}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "E", + "name": "Asub", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", + "name": "X1", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", + "name": "X same ts and dur as X1", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "C", + "name": "counter", "args": {"value": 1}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 833, "ph": "E", + "name": "", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 835, "ph": "I", + "name": "ThreadLevelI1", "args": {}}, + + {"cat": "PERF", "ts": 880, "ph": "I", "s": "g", "name": "GlobalEvent1", + "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 837, "ph": "I", + "name": "ThreadLevelI2", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 839, "ph": "C", + "name": "counter", "args": {"value": 5}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 840, "ph": "B", + "name": "A not as long a name", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "E", + "name": "A not as long a name", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "C", + "name": "counter", "args": {"value": 1}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "C", + "name": "counter", "args": {"value": 10}}, + + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 850, "ph": "B", + "name": "B", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "E", + "name": "B", "args": {}}, + + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 827, "ph": "B", + "name": "A", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 835, "ph": "I", + "name": "ThreadLevelImmediate Three", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 845, "ph": "I", + "name": "ThreadLevelImmediate4", "args": {}}, + {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 854, "ph": "E", + "name": "A", "args": {}}, + + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", + "name": "B/E over X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 10, "ts": 860, "ph": "X", + "name": "X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", + "name": "B/E under X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", + "name": "B/E under X", "args": {}}, + {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", + "name": "B/E over X", "args": {}}, + + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 870, "ph": "P", + "name": "SampleA", "args": {}}, + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 875, "ph": "P", + "name": "SampleB", "args": {}}, + {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 878, "ph": "P", + "name": "SampleC", "args": {}, "sf": 8}, + + {"cat": "__metadata", "pid": 22630, "tid": 22630, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadA"}}, + {"cat": "__metadata", "pid": 22630, "tid": 22631, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadB"}}, + {"cat": "__metadata", "pid": 22630, "tid": 22632, "ts": 0, "ph": "M", + "name": "thread_name", "args": {"name": "threadC"}} + ], + "stackFrames": { + "1": { + "category": "m1", + "name": "main" + }, + "7": { + "category": "m2", + "name": "frame7", + "parent": "1" + }, + 
"8": { + "category": "m2", + "name": "frame8", + "parent": "1" + } + } +} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json new file mode 100644 index 00000000000..12ef5bf8b2e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json @@ -0,0 +1 @@ +{"foo": ["trace_viewer"], "bar": ["unsupported", "trace_viewer"]} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html new file mode 100644 index 00000000000..15064a54f8f --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html @@ -0,0 +1,75 @@ + + + + + + + + + Profile Dashboard Demo + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html b/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html new file mode 100644 index 00000000000..4028f0e0f06 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html @@ -0,0 +1,222 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_runs_selector/BUILD b/tensorflow/tensorboard/components/tf_runs_selector/BUILD new file mode 100644 index 00000000000..30265c8d294 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_runs_selector/BUILD @@ -0,0 +1,27 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_runs_selector", + srcs = [ + "tf-runs-selector.html", + ], + path = "/tf-runs-selector", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "@org_polymer_paper_button", + "@org_polymer_paper_dialog", + "@org_polymer_paper_styles", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html b/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html new file mode 100644 index 00000000000..6964bb076de --- /dev/null +++ b/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html @@ -0,0 +1,195 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD new file mode 100644 index 00000000000..7cc192b4640 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD @@ -0,0 +1,38 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "tf_scalar_dashboard", + srcs = [ + "tf-scalar-dashboard.html", + "tf-smoothing-input.html", + ], + path = "/tf-scalar-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_color_scale", + "//tensorflow/tensorboard/components/tf_dashboard_common", + "//tensorflow/tensorboard/components/tf_imports:lodash", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_runs_selector", + 
"//tensorflow/tensorboard/components/vz_line_chart", + "@org_polymer_iron_collapse", + "@org_polymer_paper_checkbox", + "@org_polymer_paper_dropdown_menu", + "@org_polymer_paper_icon_button", + "@org_polymer_paper_input", + "@org_polymer_paper_item", + "@org_polymer_paper_menu", + "@org_polymer_paper_slider", + "@org_polymer_paper_styles", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD new file mode 100644 index 00000000000..0e892b1aa30 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD @@ -0,0 +1,27 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") + +licenses(["notice"]) # Apache 2.0 + +ts_web_library( + name = "demo", + srcs = ["index.html"], + path = "/tf-scalar-dashboard/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_backend", + "//tensorflow/tensorboard/components/tf_imports:d3", + "//tensorflow/tensorboard/components/tf_imports:polymer", + "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", + "//tensorflow/tensorboard/components/tf_scalar_dashboard", + "//tensorflow/tensorboard/demo:demo_data", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir new file mode 100644 index 00000000000..b6362b45d77 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir @@ -0,0 +1 @@ +{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json index da831a00e9d..d45f530763c 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json @@ -1,24 +1,4 @@ { - "alpha": { - "scalars": [ - "d1", - "d2", - "d3", - "d4" - ], - "histograms": [], - "images": [], - "audio": [] - }, - "beta": { - "scalars": [ - "d1", - "d2", - "d3", - "d4" - ], - "histograms": [], - "images": [], - "audio": [] - } -} + "run1": {"scalars": ["foo/sin", "foo/cos", "foo/square", "bar/square"]}, + "run2": {"scalars": ["foo/cos", "foo/square", "bar/square"]} +} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json new file mode 100644 index 00000000000..bc269395b68 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json @@ -0,0 +1 @@ +{"run2": {"foo/cos": [[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]]}, "run1": {"foo/sin": [[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]], "foo/cos": [[0.0, 0, 
1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d1.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d1.json deleted file mode 100644 index af17f5c3283..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d1.json +++ /dev/null @@ -1 +0,0 @@ -[[1436926051.074826, 84, 0.6990088224411011], [1436926530.99861, 2289, 6.9384379386901855], [1436927011.134076, 7611, 13.698328971862793], [1436927490.984256, 16147, 20.168190002441406], [1436927970.957234, 26087, 20.877344131469727], [1436928450.977514, 36241, 21.269058227539062], [1436928930.989548, 46432, 21.329505920410156], [1436929410.976308, 56629, 21.220420837402344], [1436929890.966395, 66791, 21.190065383911133], [1436930370.958199, 76936, 21.108604431152344], [1436930850.985301, 87083, 21.157001495361328], [1436931331.009261, 97161, 21.02127456665039], [1436931810.966042, 107210, 20.891658782958984], [1436932290.955417, 117262, 20.930112838745117], [1436932770.964496, 127333, 20.986324310302734], [1436933250.962592, 137430, 20.981359481811523], [1436933730.992022, 147528, 21.083036422729492], [1436934210.959831, 157635, 21.092649459838867], [1436934690.97072, 167749, 21.11568832397461], [1436935170.957944, 177869, 21.145965576171875], [1436935650.959987, 188025, 21.215585708618164], [1436936130.997541, 198206, 21.227184295654297], [1436936610.965526, 208395, 21.226459503173828], [1436937090.965581, 218592, 21.264968872070312], [1436937570.964874, 228818, 21.335866928100586], [1436938050.965706, 239021, 21.286521911621094], [1436938531.013159, 249210, 21.20963478088379], [1436939010.957926, 259415, 21.28431510925293], [1436939490.96341, 269637, 21.326831817626953], [1436939970.959372, 279876, 21.38308334350586], [1436940450.963802, 290127, 21.355499267578125], [1436940931.004537, 300349, 21.31337547302246], [1436941410.979614, 310601, 21.405778884887695], [1436941890.979674, 320872, 21.368688583374023], [1436942370.975153, 331131, 21.39077377319336], [1436942850.980459, 341399, 21.41745948791504], [1436943331.000808, 351651, 21.384023666381836], [1436943810.968736, 361904, 21.326438903808594], [1436944290.95947, 372158, 21.367351531982422], [1436944770.955783, 382430, 21.476247787475586], [1436945250.966321, 392684, 21.36678695678711], [1436945731.008667, 402950, 21.349145889282227], [1436946210.977922, 413210, 21.373897552490234], [1436946690.975303, 423463, 21.322399139404297], [1436947170.964596, 433723, 21.341150283813477], [1436947650.955017, 443991, 21.366348266601562], [1436948130.992501, 454271, 21.43844223022461], [1436948610.960555, 464519, 21.36829948425293], [1436949090.961079, 474758, 21.266357421875], [1436949570.971528, 484987, 21.316511154174805], [1436950050.977787, 495228, 21.356050491333008], [1436950531.020035, 505458, 21.31462860107422], [1436951010.959775, 515682, 21.277490615844727], [1436951490.967418, 525910, 21.289737701416016], [1436951970.969778, 536112, 21.2515811920166], [1436952450.956291, 546320, 21.254491806030273], [1436952931.005547, 556541, 21.297870635986328], [1436953410.955758, 566755, 21.320045471191406], [1436953890.959151, 576957, 
21.23529624938965], [1436954370.959553, 587165, 21.25132179260254], [1436954850.960546, 597371, 21.23470115661621], [1436955330.989932, 607582, 21.19434356689453], [1436955810.957128, 617790, 21.258535385131836], [1436956290.9763, 627991, 21.221921920776367], [1436956770.957785, 638208, 21.309843063354492], [1436957250.974143, 648404, 21.252185821533203], [1436957731.012441, 658613, 21.265626907348633], [1436958210.980787, 668824, 21.239660263061523], [1436958690.973474, 679034, 21.2642765045166], [1436959170.95825, 689249, 21.303138732910156], [1436959650.959345, 699454, 21.24073600769043], [1436960131.008682, 709664, 21.217615127563477], [1436960610.958074, 719876, 21.251184463500977], [1436961090.963638, 730100, 21.290971755981445], [1436961570.979029, 740316, 21.305265426635742], [1436962050.974645, 750534, 21.27857208251953], [1436962531.055479, 760757, 21.329837799072266], [1436963010.975299, 770964, 21.248849868774414], [1436963490.963107, 781164, 21.19978904724121], [1436963970.965936, 791382, 21.30535888671875], [1436964450.959947, 801590, 21.226255416870117], [1436964931.00587, 811785, 21.242237091064453], [1436965410.977997, 821977, 21.226497650146484], [1436965890.988465, 832189, 21.31219482421875], [1436966370.965612, 842399, 21.283390045166016], [1436966850.965794, 852612, 21.273908615112305], [1436967331.009476, 862825, 21.260452270507812], [1436967810.96767, 873037, 21.315444946289062], [1436968290.959107, 883248, 21.28677749633789], [1436968770.9681, 893452, 21.265335083007812], [1436969250.959332, 903655, 21.252891540527344], [1436969731.055609, 913856, 21.233684539794922], [1436970210.961426, 924047, 21.191429138183594], [1436970690.962999, 934250, 21.23288345336914], [1436971170.989107, 944430, 21.17190170288086], [1436971650.956015, 954634, 21.275972366333008], [1436972131.006841, 964844, 21.278474807739258], [1436972610.981754, 975045, 21.25553321838379], [1436973090.961548, 985239, 21.21686553955078], [1436973570.960013, 995439, 21.26004981994629], [1436974050.975653, 1005642, 21.25356101989746], [1436974530.988571, 1015842, 21.23944664001465], [1436975010.95851, 1026048, 21.293363571166992], [1436975490.97355, 1036253, 21.277101516723633], [1436975970.960916, 1046451, 21.242155075073242], [1436976450.990263, 1056636, 21.182037353515625], [1436976930.999578, 1066834, 21.21113395690918], [1436977410.962637, 1077031, 21.230762481689453], [1436977890.970389, 1087222, 21.232444763183594], [1436978370.959059, 1097405, 21.202342987060547], [1436978850.956562, 1107601, 21.23992156982422], [1436979331.021134, 1117786, 21.197628021240234], [1436979810.958593, 1127973, 21.2270565032959], [1436980290.958763, 1138163, 21.250303268432617], [1436980770.967171, 1148348, 21.215538024902344], [1436981250.960473, 1158540, 21.277185440063477], [1436981731.009465, 1168733, 21.268449783325195], [1436982210.960797, 1178930, 21.268077850341797], [1436982690.959709, 1189129, 21.243141174316406], [1436983170.961963, 1199327, 21.21793556213379], [1436983650.958504, 1209524, 21.2817440032959], [1436984130.998057, 1219726, 21.261478424072266], [1436984610.958945, 1229936, 21.300107955932617], [1436985090.978825, 1240145, 21.326183319091797], [1436985570.993741, 1250311, 21.115875244140625], [1436986050.965608, 1260436, 21.19010353088379], [1436986531.026713, 1270611, 21.183719635009766], [1436987010.969056, 1280784, 21.273176193237305], [1436987490.975071, 1290959, 21.182931900024414], [1436987970.96007, 1301147, 21.260244369506836], [1436988450.966092, 1311328, 21.225025177001953], 
[1436988931.004917, 1321514, 21.242164611816406], [1436989410.980351, 1331709, 21.19801139831543], [1436989890.975192, 1341910, 21.273555755615234], [1436990370.964941, 1352090, 21.175983428955078], [1436990850.973647, 1362240, 21.13412094116211], [1436991330.999346, 1372396, 21.153064727783203], [1436991811.003573, 1382550, 21.155475616455078], [1436992290.962706, 1392710, 21.17011833190918], [1436992770.999149, 1402862, 21.128713607788086], [1436993250.965124, 1413020, 21.1361026763916], [1436993731.020464, 1423164, 21.157777786254883], [1436994210.966935, 1433312, 21.119478225708008], [1436994690.962803, 1443468, 21.161104202270508], [1436995170.972952, 1453657, 21.11492919921875], [1436995650.976233, 1463820, 21.194231033325195], [1436996130.990524, 1473980, 21.169816970825195], [1436996610.97302, 1484152, 21.18223762512207], [1436997090.958457, 1494308, 21.1954402923584], [1436997570.980333, 1504463, 21.140769958496094], [1436998050.969869, 1514618, 21.162744522094727], [1436998530.99688, 1524770, 21.139591217041016], [1436999010.970375, 1534905, 21.107114791870117], [1436999490.960775, 1545070, 21.233396530151367], [1436999970.965087, 1555223, 21.201074600219727], [1437000450.969008, 1565370, 21.147083282470703], [1437000931.007425, 1575517, 21.108510971069336], [1437001410.962798, 1585666, 21.11674690246582], [1437001890.966192, 1595826, 21.17819595336914], [1437002370.961814, 1605980, 21.157669067382812], [1437002850.962206, 1616145, 21.212690353393555], [1437003330.994816, 1626291, 21.177446365356445], [1437003810.966017, 1636448, 21.17884063720703], [1437004290.959479, 1646599, 21.150310516357422], [1437004770.965083, 1656754, 21.21011734008789], [1437005250.958234, 1666902, 21.14912986755371], [1437005731.003528, 1677043, 21.125459671020508], [1437006210.961371, 1687192, 21.124374389648438], [1437006690.962663, 1697338, 21.150362014770508], [1437007170.961639, 1707484, 21.16637420654297], [1437007650.972242, 1717625, 21.163259506225586], [1437008131.003191, 1727767, 21.167280197143555], [1437008610.962644, 1737913, 21.174945831298828], [1437009090.964129, 1748068, 21.17894172668457], [1437009570.962582, 1758219, 21.116622924804688], [1437010050.984863, 1768384, 21.23469352722168], [1437010531.002295, 1778534, 21.143510818481445], [1437011010.961803, 1788677, 21.159791946411133], [1437011490.974074, 1798822, 21.119792938232422], [1437011970.959982, 1808958, 21.10943603515625], [1437012450.95932, 1819091, 21.123899459838867], [1437012931.004909, 1829227, 21.094532012939453], [1437013410.957751, 1839374, 21.200057983398438], [1437013890.960506, 1849509, 21.10895538330078], [1437014370.96113, 1859653, 21.108680725097656], [1437014850.962876, 1869791, 21.141136169433594], [1437015331.009875, 1879944, 21.160165786743164], [1437015810.960671, 1890090, 21.158742904663086], [1437016290.970743, 1900242, 21.16562271118164], [1437016770.961673, 1910391, 21.141860961914062], [1437017250.96735, 1920551, 21.19420051574707], [1437017731.000324, 1930702, 21.16814422607422], [1437018210.967878, 1940856, 21.125978469848633], [1437018690.962742, 1951005, 21.15043067932129], [1437019170.975774, 1961158, 21.157419204711914], [1437019650.964573, 1971309, 21.150177001953125], [1437020130.999343, 1981461, 21.124492645263672], [1437020610.960696, 1991611, 21.109933853149414], [1437021090.958597, 2001766, 21.169754028320312], [1437021570.964477, 2011919, 21.13479995727539], [1437022050.966522, 2022063, 21.131561279296875], [1437022531.005607, 2032219, 21.135629653930664], [1437023010.970667, 2042380, 
21.207313537597656], [1437023490.964885, 2052534, 21.108623504638672], [1437023970.965596, 2062691, 21.14097023010254], [1437024450.962296, 2072837, 21.129037857055664], [1437024931.00395, 2082982, 21.077030181884766], [1437025410.96602, 2093128, 21.13152503967285], [1437025890.961753, 2103274, 21.117740631103516], [1437026370.962022, 2113424, 21.141584396362305], [1437026850.975475, 2123570, 21.143577575683594], [1437027331.009277, 2133721, 21.175586700439453], [1437027810.97206, 2143857, 21.099014282226562], [1437028290.961523, 2154015, 21.141523361206055], [1437028770.964366, 2164168, 21.141345977783203], [1437029250.962109, 2174320, 21.14827537536621], [1437029731.003068, 2184453, 21.086946487426758], [1437030210.960946, 2194602, 21.1590576171875], [1437030690.966681, 2204754, 21.17353057861328], [1437031170.961207, 2214899, 21.133989334106445], [1437031650.962809, 2225062, 21.14800453186035], [1437032130.997644, 2235215, 21.15397834777832], [1437032610.962999, 2245366, 21.15763282775879], [1437033090.962192, 2255521, 21.133577346801758], [1437033570.963341, 2265657, 21.058490753173828], [1437034050.979501, 2275787, 21.079614639282227], [1437034531.003514, 2285923, 21.12677574157715], [1437035010.960984, 2296058, 21.100793838500977], [1437035490.97325, 2306176, 21.10753059387207], [1437035970.969759, 2316297, 21.100393295288086], [1437036450.962305, 2326428, 21.041208267211914], [1437036931.001785, 2336571, 21.15167999267578], [1437037410.967681, 2346709, 21.09291648864746], [1437037890.963194, 2356854, 21.18524932861328], [1437038370.96445, 2366985, 21.116247177124023], [1437038850.960718, 2377124, 21.125469207763672], [1437039331.003148, 2387259, 21.132274627685547], [1437039810.974007, 2397400, 21.119945526123047], [1437040290.983415, 2407539, 21.154672622680664], [1437040770.961836, 2417667, 21.066741943359375], [1437041250.964281, 2427791, 21.126564025878906], [1437041731.0196, 2437923, 21.1062068939209], [1437042210.962927, 2448056, 21.124549865722656], [1437042690.964392, 2458193, 21.13232421875], [1437043170.972024, 2468318, 21.066423416137695], [1437043650.966111, 2478449, 21.123788833618164], [1437044131.030028, 2488576, 21.138349533081055], [1437044610.962532, 2498717, 21.11895179748535], [1437045090.965094, 2508839, 21.019609451293945], [1437045570.963352, 2518972, 21.079254150390625], [1437046050.96194, 2529106, 21.15033531188965], [1437046530.995016, 2539243, 21.11912727355957], [1437047010.963313, 2549369, 21.08464813232422], [1437047490.963943, 2559509, 21.133895874023438], [1437047970.958612, 2569646, 21.108659744262695], [1437048450.962392, 2579776, 21.084848403930664], [1437048931.005408, 2589906, 21.092708587646484], [1437049410.984115, 2600033, 21.130634307861328], [1437049890.964103, 2610162, 21.074010848999023], [1437050370.960886, 2620282, 21.086149215698242], [1437050850.959795, 2630402, 21.08969497680664], [1437051331.008292, 2640533, 21.134498596191406], [1437051810.96622, 2650643, 21.065444946289062], [1437052290.98584, 2660774, 21.120830535888672], [1437052770.967707, 2670900, 21.085134506225586], [1437053250.978851, 2681021, 21.037155151367188], [1437053731.021686, 2691151, 21.09203338623047], [1437054210.971744, 2701273, 21.048450469970703], [1437054690.966686, 2711425, 21.048809051513672], [1437055170.964463, 2721564, 21.13330078125], [1437055650.97301, 2731694, 21.097095489501953], [1437056130.997053, 2741810, 21.031536102294922], [1437056610.968681, 2751927, 21.04400634765625], [1437057090.976676, 2762049, 21.114444732666016], [1437057570.962334, 
2772169, 21.06243896484375], … ([wall_time, step, value] series elided) …, [1437073290.957895, 3104118, 21.118793487548828]]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d2.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d2.json
deleted file mode 100644
index 92bb4143482..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d2.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 0.04500000178813934], … (series elided) …, [1437073293.569178, 3104172, 0.004286185372620821]]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d3.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d3.json
deleted file mode 100644
index 69191b91544..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d3.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 0.0], … (series elided) …, [1437073293.569178, 3104172, 7862.0]]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d4.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d4.json
deleted file mode 100644
index caf1ae6e7f7..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/alpha/d4.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 2.461352825164795], … (series elided) …, [1437073293.569178, 3104172, 3.8385396003723145]]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d1.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d1.json
deleted file mode 100644
index 27ff64e5ddb..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d1.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 1.009283951483666897], … (series elided) …, [1437073293.569178, 3104172, 0.021365314722061157]]
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d2.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d2.json
deleted file mode 100644
index fb5a18d53a1..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d2.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 0.01034154836088419], … (series elided) …, [1437073293.569178, 3104172, 0.026299258694052696]]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d3.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d3.json
deleted file mode 100644
index e489130ea77..00000000000
--- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d3.json
+++ /dev/null
@@ -1 +0,0 @@
-[[1436925978.257845, 7, 0.03425809368491173], … (series elided) …, [1436975373.638453, 1033759, 0.03219134360551834], [1436975853.524388, 1043955,
0.0324082113802433], [1436976333.625792, 1054148, 0.03221917897462845], [1436976813.610661, 1064342, 0.03200480341911316], [1436977293.601581, 1074539, 0.03198748826980591], [1436977773.575627, 1084733, 0.032064300030469894], [1436978253.564972, 1094914, 0.032298240810632706], [1436978733.673144, 1105109, 0.03248215466737747], [1436979213.540585, 1115293, 0.03262820467352867], [1436979693.699591, 1125483, 0.032745134085416794], [1436980173.613012, 1135670, 0.032681502401828766], [1436980653.575769, 1145862, 0.03240214288234711], [1436981133.719264, 1156045, 0.03237201273441315], [1436981613.563551, 1166236, 0.03202598914504051], [1436982093.553233, 1176436, 0.032163310796022415], [1436982573.577846, 1186636, 0.03232435882091522], [1436983053.605749, 1196837, 0.032410554587841034], [1436983533.684994, 1207025, 0.03245232254266739], [1436984013.561492, 1217233, 0.03224659338593483], [1436984493.629873, 1227437, 0.03204221650958061], [1436984973.606714, 1237643, 0.03186390548944473], [1436985453.690084, 1247835, 0.031911786645650864], [1436985933.711388, 1257951, 0.032286882400512695], [1436986413.598807, 1268125, 0.032266560941934586], [1436986893.631797, 1278290, 0.03252791985869408], [1436987373.596962, 1288473, 0.03241678699851036], [1436987853.555549, 1298650, 0.03210347890853882], [1436988333.722032, 1308841, 0.031904902309179306], [1436988813.55697, 1319018, 0.03179018944501877], [1436989293.756905, 1329221, 0.0316874124109745], [1436989773.665141, 1339417, 0.03160090371966362], [1436990253.768302, 1349610, 0.03161816671490669], [1436990733.708919, 1359759, 0.0317724235355854], [1436991213.663033, 1369914, 0.03175821527838707], [1436991693.730925, 1380074, 0.031629402190446854], [1436992173.751791, 1390224, 0.031547073274850845], [1436992653.758682, 1400383, 0.031528495252132416], [1436993133.835604, 1410542, 0.03169562667608261], [1436993613.674655, 1420684, 0.031826674938201904], [1436994093.747454, 1430832, 0.03185039013624191], [1436994573.768973, 1440986, 0.03199826925992966], [1436995053.666661, 1451174, 0.03156091645359993], [1436995533.83439, 1461345, 0.031506411731243134], [1436996013.556996, 1471495, 0.031495608389377594], [1436996493.635477, 1481663, 0.03134337440133095], [1436996973.668684, 1491822, 0.031145554035902023], [1436997453.59326, 1501979, 0.031068041920661926], [1436997933.774019, 1512139, 0.031432390213012695], [1436998413.575162, 1522290, 0.03142932057380676], [1436998893.640468, 1532431, 0.03132513165473938], [1436999373.551661, 1542579, 0.03125539794564247], [1436999853.57906, 1552734, 0.0309873279184103], [1437000333.680409, 1562888, 0.03088490664958954], [1437000813.602383, 1573037, 0.0308260228484869], [1437001293.610337, 1583190, 0.030793415382504463], [1437001773.618199, 1593341, 0.03087344579398632], [1437002253.572966, 1603497, 0.0308389812707901], [1437002733.67994, 1613657, 0.03070608340203762], [1437003213.583266, 1623809, 0.0307186096906662], [1437003693.639943, 1633966, 0.03048117645084858], [1437004173.568287, 1644113, 0.03030446544289589], [1437004653.610772, 1654268, 0.030324051156640053], [1437005133.663045, 1664424, 0.03043009154498577], [1437005613.580984, 1674567, 0.030361991375684738], [1437006093.601019, 1684715, 0.030566193163394928], [1437006573.625314, 1694857, 0.030430208891630173], [1437007053.584514, 1704999, 0.030224468559026718], [1437007533.719303, 1715150, 0.030241932719945908], [1437008013.604962, 1725282, 0.030097855255007744], [1437008493.655091, 1735432, 0.030217904597520828], [1437008973.640165, 1745584, 
0.030181601643562317], [1437009453.715067, 1755742, 0.030172593891620636], [1437009933.765712, 1765896, 0.030141659080982208], [1437010413.632128, 1776052, 0.030052196234464645], [1437010893.66766, 1786195, 0.03007938154041767], [1437011373.636164, 1796346, 0.02996920794248581], [1437011853.631224, 1806481, 0.029995175078511238], [1437012333.706205, 1816617, 0.03010040894150734], [1437012813.61987, 1826754, 0.030088385567069054], [1437013293.479904, 1836883, 0.029996229335665703], [1437013773.604574, 1847029, 0.029950618743896484], [1437014253.618884, 1857175, 0.029801754280924797], [1437014733.756419, 1867312, 0.029606210067868233], [1437015213.638607, 1877459, 0.029520301148295403], [1437015693.625763, 1887608, 0.02937021106481552], [1437016173.63194, 1897759, 0.02928493171930313], [1437016653.609074, 1907909, 0.029194936156272888], [1437017133.717601, 1918074, 0.029153617098927498], [1437017613.716011, 1928220, 0.029063349589705467], [1437018093.626005, 1938377, 0.02899051643908024], [1437018573.626522, 1948523, 0.02908063493669033], [1437019053.648174, 1958678, 0.029026903212070465], [1437019533.803011, 1968831, 0.029071694239974022], [1437020013.667751, 1978978, 0.029110101982951164], [1437020493.659028, 1989133, 0.02908976934850216], [1437020973.657346, 1999287, 0.028982611373066902], [1437021453.650634, 2009437, 0.028793690726161003], [1437021933.848661, 2019588, 0.02868787571787834], [1437022413.674963, 2029736, 0.028585631400346756], [1437022893.69086, 2039894, 0.02848806604743004], [1437023373.68883, 2050054, 0.028294002637267113], [1437023853.686116, 2060205, 0.02822807803750038], [1437024333.763876, 2070362, 0.02809525839984417], [1437024813.707845, 2080507, 0.027941878885030746], [1437025293.483294, 2090645, 0.02787884697318077], [1437025773.695712, 2100793, 0.027862509712576866], [1437026253.672994, 2110943, 0.027835993096232414], [1437026733.780775, 2121094, 0.027756690979003906], [1437027213.617849, 2131235, 0.027644263580441475], [1437027693.694451, 2141382, 0.02752007730305195], [1437028173.68596, 2151537, 0.027432529255747795], [1437028653.584833, 2161685, 0.027434471994638443], [1437029133.792483, 2171839, 0.027317894622683525], [1437029613.661672, 2181977, 0.027138294652104378], [1437030093.641009, 2192118, 0.027088705450296402], [1437030573.656274, 2202268, 0.027131302282214165], [1437031053.643631, 2212416, 0.02715957537293434], [1437031533.777478, 2222583, 0.027089620009064674], [1437032013.704008, 2232736, 0.026989320293068886], [1437032493.638393, 2242882, 0.026922713965177536], [1437032973.684986, 2253041, 0.02678370475769043], [1437033453.699562, 2263183, 0.0267350971698761], [1437033933.918074, 2273320, 0.026652036234736443], [1437034413.596351, 2283443, 0.0265977680683136], [1437034893.640496, 2293579, 0.02654072269797325], [1437035373.637761, 2303701, 0.026471523568034172], [1437035853.669947, 2313823, 0.026451298967003822], [1437036333.78905, 2323961, 0.026429779827594757], [1437036813.699727, 2334089, 0.026324886828660965], [1437037293.662592, 2344235, 0.026287589222192764], [1437037773.66716, 2354364, 0.026264755055308342], [1437038253.603687, 2364507, 0.026225194334983826], [1437038733.78864, 2374644, 0.02624845691025257], [1437039213.641799, 2384782, 0.02624380588531494], [1437039693.687078, 2394923, 0.026255516335368156], [1437040173.635717, 2405058, 0.026186630129814148], [1437040653.673331, 2415194, 0.026059549301862717], [1437041133.764768, 2425322, 0.02603207901120186], [1437041613.629279, 2435449, 0.025951188057661057], [1437042093.703985, 2445575, 
0.025885486975312233], [1437042573.496029, 2455712, 0.0258584376424551], [1437043053.686022, 2465844, 0.02576967515051365], [1437043533.731929, 2475974, 0.02574247308075428], [1437044013.636245, 2486095, 0.025741368532180786], [1437044493.69923, 2496238, 0.025613142177462578], [1437044973.652155, 2506373, 0.025502001866698265], [1437045453.691467, 2516497, 0.025422129780054092], [1437045933.935804, 2526637, 0.02530006691813469], [1437046413.635583, 2536770, 0.02533203549683094], [1437046893.626337, 2546896, 0.025261884555220604], [1437047373.67437, 2557029, 0.02517615258693695], [1437047853.652939, 2567169, 0.025054262951016426], [1437048333.778436, 2577306, 0.024978358298540115], [1437048813.654248, 2587433, 0.024952327832579613], [1437049293.610609, 2597552, 0.024846646934747696], [1437049773.646573, 2607690, 0.024763893336057663], [1437050253.667925, 2617808, 0.024688972160220146], [1437050733.735291, 2627933, 0.024599123746156693], [1437051213.620222, 2638053, 0.024585271254181862], [1437051693.601978, 2648171, 0.024474715813994408], [1437052173.634985, 2658299, 0.0243435837328434], [1437052653.687176, 2668425, 0.024294523522257805], [1437053133.762819, 2678556, 0.024163981899619102], [1437053613.643698, 2688671, 0.024034887552261353], [1437054093.673047, 2698804, 0.024000374600291252], [1437054573.667371, 2708956, 0.023914175108075142], [1437055053.650441, 2719087, 0.02395522966980934], [1437055533.778469, 2729219, 0.023859599605202675], [1437056013.694082, 2739343, 0.02375946193933487], [1437056493.674871, 2749458, 0.023677179589867592], [1437056973.700234, 2759575, 0.023645443841814995], [1437057453.666129, 2769697, 0.02356558106839657], [1437057933.848506, 2779821, 0.023514214903116226], [1437058413.643799, 2789941, 0.023489613085985184], [1437058893.715386, 2800076, 0.023429814726114273], [1437059373.62596, 2810207, 0.023343827575445175], [1437059853.650848, 2820334, 0.02322673238813877], [1437060333.792248, 2830465, 0.023134106770157814], [1437060813.682955, 2840600, 0.023055672645568848], [1437061293.681795, 2850745, 0.022985080257058144], [1437061773.691182, 2860880, 0.02291373908519745], [1437062253.662987, 2871013, 0.022864071652293205], [1437062733.760419, 2881153, 0.0227896086871624], [1437063213.651969, 2891278, 0.02276325598359108], [1437063693.723523, 2901406, 0.022676151245832443], [1437064173.68663, 2911533, 0.022622840479016304], [1437064653.547643, 2921667, 0.022551873698830605], [1437065133.62645, 2931813, 0.022431621327996254], [1437065613.566569, 2941947, 0.022368427366018295], [1437066093.537804, 2952102, 0.022323856130242348], [1437066573.529332, 2962243, 0.022268367931246758], [1437067053.520098, 2972400, 0.022210223600268364], [1437067533.605733, 2982561, 0.022118542343378067], [1437068013.535467, 2992698, 0.022013003006577492], [1437068493.559976, 3002839, 0.021971898153424263], [1437068973.558743, 3012983, 0.021911533549427986], [1437069453.562661, 3023116, 0.021851375699043274], [1437069933.627071, 3033256, 0.021762363612651825], [1437070413.574131, 3043386, 0.021733952686190605], [1437070893.658803, 3053528, 0.021669508889317513], [1437071373.638711, 3063659, 0.021594204008579254], [1437071853.621384, 3073794, 0.021531015634536743], [1437072333.665269, 3083926, 0.021499203518033028], [1437072813.584388, 3094040, 0.021456807851791382], [1437073293.569178, 3104172, 0.02136526256799698]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d4.json 
b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d4.json deleted file mode 100644 index 434b78cd0f5..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars/beta/d4.json +++ /dev/null @@ -1 +0,0 @@ -[[1436925978.257845, 7, 0.5028539896011353], [1436926413.945391, 1476, 0.4976981580257416], [1436926893.945037, 6006, 0.5092837810516357], [1436927373.995472, 13786, 0.5118998885154724], [1436927853.989794, 23650, 0.5314905643463135], [1436928334.132361, 33755, 0.550969123840332], [1436928813.973288, 43941, 0.5487659573554993], [1436929293.975949, 54146, 0.5263530015945435], [1436929773.992781, 64316, 0.5077286958694458], [1436930253.997415, 74465, 0.5120566487312317], [1436930734.203004, 84611, 0.5140185952186584], [1436931214.03644, 94700, 0.5133042335510254], [1436931694.094564, 104766, 0.5233010053634644], [1436932174.114955, 114817, 0.5230671763420105], [1436932654.161382, 124880, 0.5250263810157776], [1436933133.960214, 134977, 0.5088120698928833], [1436933614.044337, 145062, 0.5097426176071167], [1436934094.166206, 155169, 0.5103482007980347], [1436934574.106036, 165284, 0.5021579265594482], [1436935054.150647, 175402, 0.49785494804382324], [1436935533.819562, 185538, 0.4970649182796478], [1436936013.710422, 195712, 0.5023221373558044], [1436936493.609025, 205906, 0.5063169002532959], [1436936973.683892, 216099, 0.50455641746521], [1436937454.138383, 226331, 0.5104150772094727], [1436937933.838475, 236532, 0.5066487193107605], [1436938413.89688, 246724, 0.5183079838752747], [1436938894.018652, 256925, 0.5163102746009827], [1436939373.69067, 267137, 0.5216323733329773], [1436939853.673692, 277369, 0.5153006315231323], [1436940333.651346, 287620, 0.5240126252174377], [1436940813.599579, 297848, 0.5263218879699707], [1436941293.596313, 308088, 0.5236956477165222], [1436941773.659172, 318362, 0.534295916557312], [1436942253.648479, 328621, 0.540306031703949], [1436942733.752284, 338892, 0.5359382033348083], [1436943213.621881, 349144, 0.540198564529419], [1436943693.698743, 359399, 0.5404431819915771], [1436944173.578463, 369649, 0.5429667234420776], [1436944653.692217, 379912, 0.5415231585502625], [1436945133.677298, 390180, 0.54068922996521], [1436945613.572411, 400445, 0.5396349430084229], [1436946093.56123, 410703, 0.5486253499984741], [1436946573.542364, 420958, 0.5451043248176575], [1436947053.616578, 431216, 0.5478819608688354], [1436947533.636973, 441483, 0.5503379106521606], [1436948013.541574, 451751, 0.5534676313400269], [1436948493.560223, 462015, 0.5574610829353333], [1436948973.512541, 472260, 0.5558810234069824], [1436949453.550055, 482483, 0.5529404878616333], [1436949933.828011, 492731, 0.5618430972099304], [1436950413.603177, 502957, 0.5641138553619385], [1436950893.563009, 513185, 0.5707159638404846], [1436951373.620887, 523410, 0.5676558613777161], [1436951853.61941, 533618, 0.5637813806533813], [1436952333.694447, 543828, 0.5682924389839172], [1436952813.621004, 554042, 0.5690237283706665], [1436953293.588156, 564251, 0.5655006766319275], [1436953773.599734, 574464, 0.553955614566803], [1436954253.621309, 584672, 0.5558924674987793], [1436954733.738119, 594882, 0.5603042840957642], [1436955213.56617, 605091, 0.5625290870666504], [1436955693.585366, 615296, 0.5668522715568542], [1436956173.626395, 625501, 0.5736584663391113], [1436956653.601937, 635705, 0.5693879723548889], [1436957133.665878, 645915, 0.576599657535553], [1436957613.584762, 656116, 0.5648065805435181], [1436958093.549783, 666331, 
0.5632508397102356], [1436958573.646778, 676543, 0.5660487413406372], [1436959053.585655, 686750, 0.568809449672699], [1436959533.679696, 696961, 0.5667826533317566], [1436960013.633292, 707173, 0.5637232661247253], [1436960493.578778, 717383, 0.5675314664840698], [1436960973.596715, 727598, 0.5714674592018127], [1436961453.625644, 737818, 0.564845085144043], [1436961933.740339, 748040, 0.5700833797454834], [1436962413.573845, 758252, 0.5702976584434509], [1436962893.610678, 768470, 0.5745863914489746], [1436963373.642878, 778674, 0.5763651728630066], [1436963853.558388, 788877, 0.5721960067749023], [1436964333.658419, 799099, 0.5714120864868164], [1436964813.573319, 809289, 0.5687000155448914], [1436965293.542098, 819484, 0.5728974938392639], [1436965773.545453, 829687, 0.5738612413406372], [1436966253.586517, 839901, 0.5702064037322998], [1436966733.639348, 850120, 0.5715107321739197], [1436967213.697288, 860330, 0.5695001482963562], [1436967693.617172, 870539, 0.5783872008323669], [1436968173.593885, 880748, 0.5758792161941528], [1436968653.560836, 890955, 0.572809636592865], [1436969133.676337, 901164, 0.5752230286598206], [1436969613.506638, 911358, 0.5861247181892395], [1436970093.595964, 921560, 0.5834078788757324], [1436970573.541227, 931756, 0.5814791321754456], [1436971053.624316, 941945, 0.5803619623184204], [1436971533.655543, 952138, 0.5765199065208435], [1436972013.604738, 962349, 0.5693190693855286], [1436972493.613199, 972551, 0.5720453262329102], [1436972973.501155, 982746, 0.5741620063781738], [1436973453.64842, 992945, 0.5705713629722595], [1436973933.689516, 1003147, 0.5657351613044739], [1436974413.577769, 1013350, 0.5685256123542786], [1436974893.542281, 1023545, 0.5698860287666321], [1436975373.638453, 1033759, 0.5801734328269958], [1436975853.524388, 1043955, 0.577880322933197], [1436976333.625792, 1054148, 0.5780594348907471], [1436976813.610661, 1064342, 0.5804633498191833], [1436977293.601581, 1074539, 0.5842364430427551], [1436977773.575627, 1084733, 0.5745837092399597], [1436978253.564972, 1094914, 0.5848771333694458], [1436978733.673144, 1105109, 0.5795935392379761], [1436979213.540585, 1115293, 0.583346426486969], [1436979693.699591, 1125483, 0.5840965509414673], [1436980173.613012, 1135670, 0.5807850360870361], [1436980653.575769, 1145862, 0.5843925476074219], [1436981133.719264, 1156045, 0.5828814506530762], [1436981613.563551, 1166236, 0.5873864889144897], [1436982093.553233, 1176436, 0.5896572470664978], [1436982573.577846, 1186636, 0.5887367725372314], [1436983053.605749, 1196837, 0.5841871500015259], [1436983533.684994, 1207025, 0.5867579579353333], [1436984013.561492, 1217233, 0.5940297842025757], [1436984493.629873, 1227437, 0.5925037860870361], [1436984973.606714, 1237643, 0.5981529951095581], [1436985453.690084, 1247835, 0.5954598188400269], [1436985933.711388, 1257951, 0.5903756022453308], [1436986413.598807, 1268125, 0.5837404131889343], [1436986893.631797, 1278290, 0.583182156085968], [1436987373.596962, 1288473, 0.5860618352890015], [1436987853.555549, 1298650, 0.5829544067382812], [1436988333.722032, 1308841, 0.5798720121383667], [1436988813.55697, 1319018, 0.589148998260498], [1436989293.756905, 1329221, 0.5905702710151672], [1436989773.665141, 1339417, 0.5900465250015259], [1436990253.768302, 1349610, 0.5893078446388245], [1436990733.708919, 1359759, 0.589722752571106], [1436991213.663033, 1369914, 0.5907371640205383], [1436991693.730925, 1380074, 0.5939858555793762], [1436992173.751791, 1390224, 0.5906378626823425], [1436992653.758682, 
1400383, 0.5876493453979492], [1436993133.835604, 1410542, 0.5912420153617859], [1436993613.674655, 1420684, 0.5887293219566345], [1436994093.747454, 1430832, 0.589107096195221], [1436994573.768973, 1440986, 0.5928497910499573], [1436995053.666661, 1451174, 0.5916265845298767], [1436995533.83439, 1461345, 0.5911784768104553], [1436996013.556996, 1471495, 0.5890726447105408], [1436996493.635477, 1481663, 0.5914839506149292], [1436996973.668684, 1491822, 0.5915400385856628], [1436997453.59326, 1501979, 0.591564416885376], [1436997933.774019, 1512139, 0.5926578640937805], [1436998413.575162, 1522290, 0.5942149758338928], [1436998893.640468, 1532431, 0.5931802988052368], [1436999373.551661, 1542579, 0.587592601776123], [1436999853.57906, 1552734, 0.5877953171730042], [1437000333.680409, 1562888, 0.590681791305542], [1437000813.602383, 1573037, 0.5924896001815796], [1437001293.610337, 1583190, 0.5913501381874084], [1437001773.618199, 1593341, 0.5952408909797668], [1437002253.572966, 1603497, 0.5953922271728516], [1437002733.67994, 1613657, 0.6002237200737], [1437003213.583266, 1623809, 0.6042569875717163], [1437003693.639943, 1633966, 0.6017740368843079], [1437004173.568287, 1644113, 0.6037994623184204], [1437004653.610772, 1654268, 0.6037947535514832], [1437005133.663045, 1664424, 0.6028310060501099], [1437005613.580984, 1674567, 0.603211522102356], [1437006093.601019, 1684715, 0.6052727699279785], [1437006573.625314, 1694857, 0.6032628417015076], [1437007053.584514, 1704999, 0.5978461503982544], [1437007533.719303, 1715150, 0.602828323841095], [1437008013.604962, 1725282, 0.6063790917396545], [1437008493.655091, 1735432, 0.6047347784042358], [1437008973.640165, 1745584, 0.6031648516654968], [1437009453.715067, 1755742, 0.6067507863044739], [1437009933.765712, 1765896, 0.6062817573547363], [1437010413.632128, 1776052, 0.609245240688324], [1437010893.66766, 1786195, 0.6066284775733948], [1437011373.636164, 1796346, 0.6102170944213867], [1437011853.631224, 1806481, 0.609173595905304], [1437012333.706205, 1816617, 0.6035751104354858], [1437012813.61987, 1826754, 0.604059636592865], [1437013293.479904, 1836883, 0.6039224863052368], [1437013773.604574, 1847029, 0.5974730849266052], [1437014253.618884, 1857175, 0.6040806174278259], [1437014733.756419, 1867312, 0.6017186045646667], [1437015213.638607, 1877459, 0.5987159609794617], [1437015693.625763, 1887608, 0.6047909259796143], [1437016173.63194, 1897759, 0.6033824682235718], [1437016653.609074, 1907909, 0.6038352847099304], [1437017133.717601, 1918074, 0.6083348989486694], [1437017613.716011, 1928220, 0.6044996380805969], [1437018093.626005, 1938377, 0.6009799242019653], [1437018573.626522, 1948523, 0.60047847032547], [1437019053.648174, 1958678, 0.6019382476806641], [1437019533.803011, 1968831, 0.6007305383682251], [1437020013.667751, 1978978, 0.6025127172470093], [1437020493.659028, 1989133, 0.6051828861236572], [1437020973.657346, 1999287, 0.6085876822471619], [1437021453.650634, 2009437, 0.6065122485160828], [1437021933.848661, 2019588, 0.6084572076797485], [1437022413.674963, 2029736, 0.6065473556518555], [1437022893.69086, 2039894, 0.6075063347816467], [1437023373.68883, 2050054, 0.6095973253250122], [1437023853.686116, 2060205, 0.6047213077545166], [1437024333.763876, 2070362, 0.6034210324287415], [1437024813.707845, 2080507, 0.6008927822113037], [1437025293.483294, 2090645, 0.604469895362854], [1437025773.695712, 2100793, 0.6068717837333679], [1437026253.672994, 2110943, 0.6099737882614136], [1437026733.780775, 2121094, 
0.6105009317398071], [1437027213.617849, 2131235, 0.611957311630249], [1437027693.694451, 2141382, 0.6141949892044067], [1437028173.68596, 2151537, 0.6135279536247253], [1437028653.584833, 2161685, 0.6111017465591431], [1437029133.792483, 2171839, 0.6135671138763428], [1437029613.661672, 2181977, 0.6112024188041687], [1437030093.641009, 2192118, 0.6097264289855957], [1437030573.656274, 2202268, 0.6097284555435181], [1437031053.643631, 2212416, 0.6121350526809692], [1437031533.777478, 2222583, 0.6147991418838501], [1437032013.704008, 2232736, 0.6118316054344177], [1437032493.638393, 2242882, 0.6191433072090149], [1437032973.684986, 2253041, 0.6188027262687683], [1437033453.699562, 2263183, 0.6163974404335022], [1437033933.918074, 2273320, 0.6144159436225891], [1437034413.596351, 2283443, 0.6123769879341125], [1437034893.640496, 2293579, 0.6139131188392639], [1437035373.637761, 2303701, 0.6150627136230469], [1437035853.669947, 2313823, 0.6149951219558716], [1437036333.78905, 2323961, 0.6155945658683777], [1437036813.699727, 2334089, 0.613308310508728], [1437037293.662592, 2344235, 0.6153736114501953], [1437037773.66716, 2354364, 0.6160987615585327], [1437038253.603687, 2364507, 0.611574113368988], [1437038733.78864, 2374644, 0.6145234107971191], [1437039213.641799, 2384782, 0.6117951273918152], [1437039693.687078, 2394923, 0.6129845380783081], [1437040173.635717, 2405058, 0.6095831394195557], [1437040653.673331, 2415194, 0.6110679507255554], [1437041133.764768, 2425322, 0.6099690198898315], [1437041613.629279, 2435449, 0.6105908155441284], [1437042093.703985, 2445575, 0.6124749779701233], [1437042573.496029, 2455712, 0.6118302345275879], [1437043053.686022, 2465844, 0.6094756722450256], [1437043533.731929, 2475974, 0.6094986796379089], [1437044013.636245, 2486095, 0.6114639639854431], [1437044493.69923, 2496238, 0.6101082563400269], [1437044973.652155, 2506373, 0.6105718612670898], [1437045453.691467, 2516497, 0.6115666627883911], [1437045933.935804, 2526637, 0.6128115653991699], [1437046413.635583, 2536770, 0.6122986078262329], [1437046893.626337, 2546896, 0.6142017245292664], [1437047373.67437, 2557029, 0.6111341714859009], [1437047853.652939, 2567169, 0.611350417137146], [1437048333.778436, 2577306, 0.6126709580421448], [1437048813.654248, 2587433, 0.6111524105072021], [1437049293.610609, 2597552, 0.6135894060134888], [1437049773.646573, 2607690, 0.6136029362678528], [1437050253.667925, 2617808, 0.6141685843467712], [1437050733.735291, 2627933, 0.6170881390571594], [1437051213.620222, 2638053, 0.6189730167388916], [1437051693.601978, 2648171, 0.6157540678977966], [1437052173.634985, 2658299, 0.6178646683692932], [1437052653.687176, 2668425, 0.6164441108703613], [1437053133.762819, 2678556, 0.6175132393836975], [1437053613.643698, 2688671, 0.6158696413040161], [1437054093.673047, 2698804, 0.6162974238395691], [1437054573.667371, 2708956, 0.6160892844200134], [1437055053.650441, 2719087, 0.6176281571388245], [1437055533.778469, 2729219, 0.6165231466293335], [1437056013.694082, 2739343, 0.6171510219573975], [1437056493.674871, 2749458, 0.6124134659767151], [1437056973.700234, 2759575, 0.6120688319206238], [1437057453.666129, 2769697, 0.6126770377159119], [1437057933.848506, 2779821, 0.6126595139503479], [1437058413.643799, 2789941, 0.616513729095459], [1437058893.715386, 2800076, 0.6130264401435852], [1437059373.62596, 2810207, 0.6114044785499573], [1437059853.650848, 2820334, 0.6077002882957458], [1437060333.792248, 2830465, 0.6086235046386719], [1437060813.682955, 2840600, 
0.6084680557250977], [1437061293.681795, 2850745, 0.6094310879707336], [1437061773.691182, 2860880, 0.6066345572471619], [1437062253.662987, 2871013, 0.6094250082969666], [1437062733.760419, 2881153, 0.609106719493866], [1437063213.651969, 2891278, 0.6080747246742249], [1437063693.723523, 2901406, 0.6081057786941528], [1437064173.68663, 2911533, 0.6066460609436035], [1437064653.547643, 2921667, 0.6057829856872559], [1437065133.62645, 2931813, 0.6092885136604309], [1437065613.566569, 2941947, 0.6089289784431458], [1437066093.537804, 2952102, 0.6070758700370789], [1437066573.529332, 2962243, 0.6096142530441284], [1437067053.520098, 2972400, 0.609714925289154], [1437067533.605733, 2982561, 0.6116167306900024], [1437068013.535467, 2992698, 0.6119107007980347], [1437068493.559976, 3002839, 0.6119140386581421], [1437068973.558743, 3012983, 0.6115538477897644], [1437069453.562661, 3023116, 0.6126777529716492], [1437069933.627071, 3033256, 0.6146017909049988], [1437070413.574131, 3043386, 0.6119789481163025], [1437070893.658803, 3053528, 0.6139205694198608], [1437071373.638711, 3063659, 0.612362802028656], [1437071853.621384, 3073794, 0.6109192371368408], [1437072333.665269, 3083926, 0.6141091585159302], [1437072813.584388, 3094040, 0.6132751703262329], [1437073293.569178, 3104172, 0.6132386922836304]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json new file mode 100644 index 00000000000..6d584fb4a9e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json @@ -0,0 +1 @@ +[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json new file mode 100644 index 00000000000..025eaa16e93 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json @@ -0,0 +1 @@ +[[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json new file mode 100644 index 00000000000..eae69dd78f3 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json @@ -0,0 +1 @@ +[[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json new file mode 100644 index 00000000000..6d584fb4a9e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json @@ -0,0 +1 @@ +[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git 
a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json new file mode 100644 index 00000000000..6d584fb4a9e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json @@ -0,0 +1 @@ +[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json new file mode 100644 index 00000000000..dd3593f9d10 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json @@ -0,0 +1 @@ +[[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json new file mode 100644 index 00000000000..0ff9ef0551d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json @@ -0,0 +1 @@ +[[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html index 02646c6c180..78f657b4104 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html @@ -16,33 +16,55 @@ See the License for the specific language governing permissions and limitations under the License. --> - - - - - - Event Dashboard Demo Demo - - - + + + + + + + +Scalar Dashboard Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html index d4688bb7c48..848ed5292de 100644 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html +++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html @@ -23,6 +23,7 @@ limitations under the License. + @@ -58,54 +59,64 @@ contains vz-line-charts embedded inside tf-panes-helper's. @@ -237,4 +237,5 @@ limitations under the License. 
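The demo scalar files added above all share TensorBoard's scalar-series layout: a JSON array of [wall_time, step, value] triples, with wall_time in seconds (note that each run2 foo/* series is exactly twice its run1 counterpart). As a rough illustration of consuming that layout, a small TypeScript reader might look like the sketch below; the `parseScalars` helper is hypothetical and not part of this change.

```ts
// Each point in the demo scalar files above is a [wall_time, step, value]
// triple, with wall_time in seconds since the epoch.
type ScalarPoint = [number, number, number];

// Hypothetical helper (not part of this diff): convert one demo file's
// contents into named fields suitable for charting.
function parseScalars(
    json: string): Array<{wallTime: Date, step: number, value: number}> {
  return (JSON.parse(json) as ScalarPoint[]).map(([wallTime, step, value]) => ({
    wallTime: new Date(wallTime * 1000),  // seconds -> milliseconds
    step,
    value,
  }));
}

// e.g. parseScalars('[[0.0, 0, 0.0], [10.0, 1, 1.0]]')[1].value === 1.0
```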
+ diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts index 20dc67167f9..3ee2c2165f2 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts @@ -17,6 +17,7 @@ import {DistanceFunction, SpriteAndMetadataInfo, State} from './data'; import * as knn from './knn'; import {ProjectorEventContext} from './projectorEventContext'; import * as adapter from './projectorScatterPlotAdapter'; +import * as util from './util'; import * as vector from './vector'; import {Projector} from './vz-projector'; import {ProjectorInput} from './vz-projector-input'; @@ -40,23 +41,24 @@ export class InspectorPanel extends PolymerClass { private selectedMetadataField: string; private metadataFields: string[]; - private dom: d3.Selection; private projector: Projector; private selectedPointIndices: number[]; private neighborsOfFirstPoint: knn.NearestEntry[]; private searchBox: ProjectorInput; - private resetFilterButton: d3.Selection; - private setFilterButton: d3.Selection; - private clearSelectionButton: d3.Selection; - private limitMessage: d3.Selection; + private resetFilterButton: HTMLButtonElement; + private setFilterButton: HTMLButtonElement; + private clearSelectionButton: HTMLButtonElement; + private limitMessage: HTMLDivElement; ready() { - this.dom = d3.select(this); - this.resetFilterButton = this.dom.select('.reset-filter'); - this.setFilterButton = this.dom.select('.set-filter'); - this.clearSelectionButton = this.dom.select('.clear-selection'); - this.limitMessage = this.dom.select('.limit-msg'); + this.resetFilterButton = + this.querySelector('.reset-filter') as HTMLButtonElement; + this.setFilterButton = + this.querySelector('.set-filter') as HTMLButtonElement; + this.clearSelectionButton = + this.querySelector('.clear-selection') as HTMLButtonElement; + this.limitMessage = this.querySelector('.limit-msg') as HTMLDivElement; this.searchBox = this.querySelector('#search-box') as ProjectorInput; // https://www.polymer-project.org/1.0/docs/devguide/styling#scope-subtree this.scopeSubtree(this, true); @@ -88,7 +90,7 @@ export class InspectorPanel extends PolymerClass { } private enableResetFilterButton(enabled: boolean) { - this.resetFilterButton.attr('disabled', enabled ? null : true); + this.resetFilterButton.disabled = !enabled; } restoreUIFromBookmark(bookmark: State) { @@ -113,143 +115,178 @@ export class InspectorPanel extends PolymerClass { } private updateSearchResults(indices: number[]) { - let container = this.dom.select('.matches-list'); - container.style('display', indices.length ? null : 'none'); - let list = container.select('.list'); - list.html(''); + const container = this.querySelector('.matches-list') as HTMLDivElement; + container.style.display = indices.length ? null : 'none'; + const list = container.querySelector('.list') as HTMLDivElement; + list.innerHTML = ''; if (indices.length === 0) { return; } - this.limitMessage.style( - 'display', indices.length <= LIMIT_RESULTS ? 'none' : null); + + this.limitMessage.style.display = + indices.length <= LIMIT_RESULTS ? 
'none' : null; indices = indices.slice(0, LIMIT_RESULTS); - let rows = list.selectAll('.row').data(indices).enter().append('div').attr( - 'class', 'row'); - rows.append('a') - .attr('class', 'label') - .attr('title', index => this.getLabelFromIndex(index)) - .text(index => this.getLabelFromIndex(index)); - rows.on('mouseenter', index => { - this.projectorEventContext.notifyHoverOverPoint(index); - }); - rows.on('mouseleave', () => { - this.projectorEventContext.notifyHoverOverPoint(null); - }); - rows.on('click', index => { - this.projectorEventContext.notifySelectionChanged([index]); - }); + + for (let i = 0; i < indices.length; i++) { + const index = indices[i]; + + const row = document.createElement('div'); + row.className = 'row'; + + const label = this.getLabelFromIndex(index); + const rowLink = document.createElement('a'); + rowLink.className = 'label'; + rowLink.title = label; + rowLink.innerText = label; + + rowLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(index); + }; + rowLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); + }; + rowLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([index]); + }; + + row.appendChild(rowLink); + list.appendChild(row); + } } private getLabelFromIndex(pointIndex: number): string { - let point = this.projector.dataSet.points[pointIndex]; + const point = this.projector.dataSet.points[pointIndex]; return point.metadata[this.selectedMetadataField].toString(); } private updateNeighborsList(neighbors: knn.NearestEntry[]) { - let nnlist = this.dom.select('.nn-list'); - nnlist.html(''); - this.dom.select('.nn').style('display', neighbors.length ? null : 'none'); + const nnlist = this.querySelector('.nn-list') as HTMLDivElement; + nnlist.innerHTML = ''; + + (this.querySelector('.nn') as HTMLDivElement).style.display = + neighbors.length ? null : 'none'; if (neighbors.length === 0) { return; } this.searchBox.message = ''; - let minDist = neighbors.length > 0 ? neighbors[0].dist : 0; - let n = nnlist.selectAll('.neighbor') - .data(neighbors) - .enter() - .append('div') - .attr('class', 'neighbor') - .append('a') - .attr('class', 'neighbor-link') - .attr('title', d => this.getLabelFromIndex(d.index)); + const minDist = neighbors.length > 0 ? 
neighbors[0].dist : 0; + for (let i = 0; i < neighbors.length; i++) { + const neighbor = neighbors[i]; - let labelValue = n.append('div').attr('class', 'label-and-value'); - labelValue.append('div') - .attr('class', 'label') - .style('color', d => adapter.dist2color(this.distFunc, d.dist, minDist)) - .text(d => this.getLabelFromIndex(d.index)); + const neighborElement = document.createElement('div'); + neighborElement.className = 'neighbor'; - labelValue.append('div') - .attr('class', 'value') - .text(d => d.dist.toFixed(3)); + const neighborElementLink = document.createElement('a'); + neighborElementLink.className = 'neighbor-link'; + neighborElementLink.title = this.getLabelFromIndex(neighbor.index); - let bar = n.append('div').attr('class', 'bar'); + const labelValueElement = document.createElement('div'); + labelValueElement.className = 'label-and-value'; - bar.append('div') - .attr('class', 'fill') - .style( - 'border-top-color', - d => { - return adapter.dist2color(this.distFunc, d.dist, minDist); - }) - .style( - 'width', - d => adapter.normalizeDist(this.distFunc, d.dist, minDist) * 100 + - '%'); + const labelElement = document.createElement('div'); + labelElement.className = 'label'; + labelElement.style.color = + adapter.dist2color(this.distFunc, neighbor.dist, minDist); + labelElement.innerText = this.getLabelFromIndex(neighbor.index); - bar.selectAll('.tick') - .data(d3.range(1, 4)) - .enter() - .append('div') - .attr('class', 'tick') - .style('left', d => d * 100 / 4 + '%'); - n.on('mouseenter', d => { - this.projectorEventContext.notifyHoverOverPoint(d.index); - }); - n.on('mouseleave', () => { - this.projectorEventContext.notifyHoverOverPoint(null); - }); - n.on('click', d => { - this.projectorEventContext.notifySelectionChanged([d.index]); - }); + const valueElement = document.createElement('div'); + valueElement.className = 'value'; + valueElement.innerText = neighbor.dist.toFixed(3); + + labelValueElement.appendChild(labelElement); + labelValueElement.appendChild(valueElement); + + const barElement = document.createElement('div'); + barElement.className = 'bar'; + + const barFillElement = document.createElement('div'); + barFillElement.className = 'fill'; + barFillElement.style.borderTopColor = + adapter.dist2color(this.distFunc, neighbor.dist, minDist); + barFillElement.style.width = + adapter.normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 + + '%'; + barElement.appendChild(barFillElement); + + for (let j = 1; j < 4; j++) { + const tickElement = document.createElement('div'); + tickElement.className = 'tick'; + tickElement.style.left = j * 100 / 4 + '%'; + barElement.appendChild(tickElement); + } + + neighborElementLink.appendChild(labelValueElement); + neighborElementLink.appendChild(barElement); + neighborElement.appendChild(neighborElementLink); + nnlist.appendChild(neighborElement); + + neighborElementLink.onmouseenter = () => { + this.projectorEventContext.notifyHoverOverPoint(neighbor.index); + }; + neighborElementLink.onmouseleave = () => { + this.projectorEventContext.notifyHoverOverPoint(null); + }; + neighborElementLink.onclick = () => { + this.projectorEventContext.notifySelectionChanged([neighbor.index]); + }; + } } private updateFilterButtons(numPoints: number) { if (numPoints > 1) { - this.setFilterButton.text(`Isolate ${numPoints} points`) - .attr('disabled', null); - this.clearSelectionButton.attr('disabled', null); + this.setFilterButton.innerText = `Isolate ${numPoints} points`; + this.setFilterButton.disabled = null; + 
this.clearSelectionButton.disabled = null; } else { - this.setFilterButton.attr('disabled', true); - this.clearSelectionButton.attr('disabled', true); + this.setFilterButton.disabled = true; + this.clearSelectionButton.disabled = true; } } private setupUI(projector: Projector) { this.distFunc = vector.cosDist; - let eucDist = this.dom.select('.distance a.euclidean'); - eucDist.on('click', () => { - this.dom.selectAll('.distance a').classed('selected', false); - eucDist.classed('selected', true); + const eucDist = + this.querySelector('.distance a.euclidean') as HTMLLinkElement; + eucDist.onclick = () => { + const links = this.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(eucDist as HTMLElement, 'selected', true); + this.distFunc = vector.dist; this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - let neighbors = projector.dataSet.findNeighbors( + const neighbors = projector.dataSet.findNeighbors( this.selectedPointIndices[0], this.distFunc, this.numNN); this.updateNeighborsList(neighbors); - }); + }; + + const cosDist = this.querySelector('.distance a.cosine') as HTMLLinkElement; + cosDist.onclick = () => { + const links = this.querySelectorAll('.distance a'); + for (let i = 0; i < links.length; i++) { + util.classed(links[i] as HTMLElement, 'selected', false); + } + util.classed(cosDist, 'selected', true); - let cosDist = this.dom.select('.distance a.cosine'); - cosDist.on('click', () => { - this.dom.selectAll('.distance a').classed('selected', false); - cosDist.classed('selected', true); this.distFunc = vector.cosDist; this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - let neighbors = projector.dataSet.findNeighbors( + const neighbors = projector.dataSet.findNeighbors( this.selectedPointIndices[0], this.distFunc, this.numNN); this.updateNeighborsList(neighbors); - }); + }; // Called whenever the search text input changes. - let updateInput = (value: string, inRegexMode: boolean) => { + const updateInput = (value: string, inRegexMode: boolean) => { if (value == null || value.trim() === '') { this.searchBox.message = ''; this.projectorEventContext.notifySelectionChanged([]); return; } - let indices = projector.dataSet.query( + const indices = projector.dataSet.query( value, inRegexMode, this.selectedMetadataField); if (indices.length === 0) { this.searchBox.message = '0 matches.'; @@ -263,10 +300,11 @@ export class InspectorPanel extends PolymerClass { }); // Nearest neighbors controls. - let numNNInput = this.$$('#nn-slider') as HTMLInputElement; - let updateNumNN = () => { + const numNNInput = this.$$('#nn-slider') as HTMLInputElement; + const updateNumNN = () => { this.numNN = +numNNInput.value; - this.dom.select('.num-nn .nn-count').text(this.numNN); + (this.querySelector('.num-nn .nn-count') as HTMLSpanElement).innerText = + '' + this.numNN; if (this.selectedPointIndices != null) { this.projectorEventContext.notifySelectionChanged( [this.selectedPointIndices[0]]); @@ -276,22 +314,22 @@ export class InspectorPanel extends PolymerClass { updateNumNN(); // Filtering dataset. 
- this.setFilterButton.on('click', () => { + this.setFilterButton.onclick = () => { const indices = this.selectedPointIndices.concat( this.neighborsOfFirstPoint.map(n => n.index)); projector.filterDataset(indices); this.enableResetFilterButton(true); this.updateFilterButtons(0); - }); + }; - this.resetFilterButton.on('click', () => { + this.resetFilterButton.onclick = () => { projector.resetFilterDataset(); this.enableResetFilterButton(false); - }); + }; - this.clearSelectionButton.on('click', () => { + this.clearSelectionButton.onclick = () => { projector.adjustSelectionAndHover([]); - }); + }; this.enableResetFilterButton(false); } } diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html index 3fc5f4db158..4b98d8bded8 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html @@ -17,6 +17,7 @@ limitations under the License. + - \ No newline at end of file + +
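Several hunks in this diff (the inspector-panel changes above and the projections-panel changes below) drop d3's `selection.classed` in favor of a `util.classed` helper imported from `./util`. That helper's body is not shown in this excerpt; a minimal sketch consistent with its call sites, `classed(element, className, enabled)`, mirroring d3's semantics, could be:

```ts
// Sketch only; the real ./util.ts may differ. Adds `className` to the
// element when `enabled` is true and removes it otherwise, like d3's
// selection.classed(className, enabled).
export function classed(
    element: HTMLElement, className: string, enabled: boolean): void {
  // The two-argument form of classList.toggle adds or removes the class
  // based on the boolean `force` parameter.
  element.classList.toggle(className, enabled);
}
```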
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts
index d30a9554805..1c4ddf940dc 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts
@@ -44,11 +44,6 @@ export interface ColorLegendThreshold {
 export class Legend extends LegendPolymer {
   renderInfo: ColorLegendRenderInfo;
-  dom: d3.Selection;
-
-  ready() {
-    this.dom = d3.select(this);
-  }
 
   _renderInfoChanged() {
     if (this.renderInfo == null) {
@@ -70,29 +65,32 @@
   }
 
   private getOffset(value: number): string {
-    let min = this.renderInfo.thresholds[0].value;
-    let max =
+    const min = this.renderInfo.thresholds[0].value;
+    const max =
         this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1].value;
     return (100 * (value - min) / (max - min)).toFixed(2) + '%';
   }
 
   private setupLinearGradient() {
-    let linearGradient = this.dom.select('#gradient');
+    const linearGradient =
+        this.querySelector('#gradient') as SVGLinearGradientElement;
 
-    let width =
-        (this.dom.select('svg.gradient').node() as SVGElement).clientWidth;
+    const width =
+        (this.querySelector('svg.gradient') as SVGElement).clientWidth;
     // Set the svg to be the width of its parent.
-    this.dom.select('svg.gradient rect').attr('width', width);
+    (this.querySelector('svg.gradient rect') as SVGRectElement).style.width =
+        width + 'px';
 
     // Remove all children from before.
-    linearGradient.selectAll('*').remove();
+    linearGradient.innerHTML = '';
 
     // Add a child in for each gradient threshold.
     this.renderInfo.thresholds.forEach(t => {
-      linearGradient.append('stop')
-          .attr('offset', this.getOffset(t.value))
-          .attr('stop-color', t.color);
+      const stopElement =
+          document.createElementNS('http://www.w3.org/2000/svg', 'stop');
+      stopElement.setAttribute('offset', this.getOffset(t.value));
+      stopElement.setAttribute('stop-color', t.color);
+      // Without this append the <stop> is created but never attached,
+      // leaving the gradient empty.
+      linearGradient.appendChild(stopElement);
     });
   }
 }
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
index ebdcd72c77d..4231a61ff30 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html
@@ -18,6 +18,7 @@ limitations under the License.
+ + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts index 17a4700bb5c..939300f3878 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts @@ -28,8 +28,6 @@ export let MetadataCardPolymer = PolymerElement({ }); export class MetadataCard extends MetadataCardPolymer { - private dom: d3.Selection; - hasMetadata: boolean; metadata: Array<{key: string, value: string}>; label: string; @@ -37,22 +35,28 @@ export class MetadataCard extends MetadataCardPolymer { private labelOption: string; private pointMetadata: PointMetadata; - ready() { - this.dom = d3.select(this); - } + private expandLessButton: HTMLButtonElement; + private expandMoreButton: HTMLButtonElement; + ready() { + this.expandLessButton = + this.querySelector('#expand-less') as HTMLButtonElement; + this.expandMoreButton = + this.querySelector('#expand-more') as HTMLButtonElement; + } /** Handles a click on the expand more icon. */ _expandMore() { (this.$$('#metadata-container') as any).toggle(); - this.dom.select('#expand-more').style('display', 'none'); - this.dom.select('#expand-less').style('display', ''); + + this.expandMoreButton.style.display = 'none'; + this.expandLessButton.style.display = ''; } /** Handles a click on the expand less icon. */ _expandLess() { (this.$$('#metadata-container') as any).toggle(); - this.dom.select('#expand-more').style('display', ''); - this.dom.select('#expand-less').style('display', 'none'); + this.expandMoreButton.style.display = ''; + this.expandLessButton.style.display = 'none'; } updateMetadata(pointMetadata?: PointMetadata) { diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html index cddcb2b7d08..b82f3f520b5 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html @@ -30,6 +30,7 @@ limitations under the License. + + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts index 9df182ed489..377c6c11ad5 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts @@ -15,6 +15,7 @@ limitations under the License. import * as data from './data'; import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; +import * as util from './util'; import * as vector from './vector'; import {Vector} from './vector'; import {Projector} from './vz-projector'; @@ -92,13 +93,12 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { public customSelectedSearchByMetadataOption: string; /** Polymer elements. 
*/ - private dom: d3.Selection; - private runTsneButton: d3.Selection; - private stopTsneButton: d3.Selection; + private runTsneButton: HTMLButtonElement; + private stopTsneButton: HTMLButtonElement; private perplexitySlider: HTMLInputElement; private learningRateInput: HTMLInputElement; - private zDropdown: d3.Selection; - private iterationLabel: d3.Selection; + private zDropdown: HTMLElement; + private iterationLabel: HTMLElement; private customProjectionXLeftInput: ProjectorInput; private customProjectionXRightInput: ProjectorInput; @@ -121,14 +121,14 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { } ready() { - this.dom = d3.select(this); - this.zDropdown = this.dom.select('#z-dropdown'); - this.runTsneButton = this.dom.select('.run-tsne'); - this.stopTsneButton = this.dom.select('.stop-tsne'); - this.perplexitySlider = this.$$('#perplexity-slider') as HTMLInputElement; + this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement; + this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement; + this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement; + this.perplexitySlider = + this.querySelector('#perplexity-slider') as HTMLInputElement; this.learningRateInput = - this.$$('#learning-rate-slider') as HTMLInputElement; - this.iterationLabel = this.dom.select('.run-tsne-iter'); + this.querySelector('#learning-rate-slider') as HTMLInputElement; + this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement; } disablePolymerChangesTriggerReprojection() { @@ -143,27 +143,33 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { if (this.perplexitySlider) { this.perplexity = +this.perplexitySlider.value; } - this.dom.select('.tsne-perplexity span').text(this.perplexity); + (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText = + '' + this.perplexity; } private updateTSNELearningRateFromUIChange() { if (this.learningRateInput) { this.learningRate = Math.pow(10, +this.learningRateInput.value); } - this.dom.select('.tsne-learning-rate span').text(this.learningRate); + (this.querySelector('.tsne-learning-rate span') as HTMLSpanElement) + .innerText = '' + this.learningRate; } private setupUIControls() { { const self = this; - this.dom.selectAll('.ink-tab').on('click', function() { - let id = this.getAttribute('data-tab'); - self.showTab(id); - }); + const inkTabs = this.querySelectorAll('.ink-tab'); + for (let i = 0; i < inkTabs.length; i++) { + inkTabs[i].addEventListener('click', function() { + let id = this.getAttribute('data-tab'); + self.showTab(id); + }); + } } - this.runTsneButton.on('click', () => this.runTSNE()); - this.stopTsneButton.on('click', () => this.dataSet.stopTSNE()); + this.runTsneButton.addEventListener('click', () => this.runTSNE()); + this.stopTsneButton.addEventListener( + 'click', () => this.dataSet.stopTSNE()); this.perplexitySlider.value = this.perplexity.toString(); this.perplexitySlider.addEventListener( @@ -177,8 +183,11 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.setupCustomProjectionInputFields(); // TODO: figure out why `--paper-input-container-input` css mixin didn't // work. 
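A note on the iteration style in the setupUIControls hunk above: the ts_web_library rule introduced later in this diff compiles with target es5, and under that target a NodeList returned by querySelectorAll is not iterable, so for...of over it would not compile; the code therefore walks results with index loops. A sketch of that idiom as a reusable helper (the helper itself is hypothetical, not in the diff):

```ts
// Hypothetical helper capturing the NodeList idiom used in this diff:
// under an ES5 target, NodeList offers length and numeric indexing but
// no iterator and no array methods, so it is walked by index.
function forEachElement(
    nodes: NodeListOf<Element>, fn: (el: HTMLElement) => void) {
  for (let i = 0; i < nodes.length; i++) {
    fn(nodes[i] as HTMLElement);
  }
}
```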
- this.dom.selectAll('paper-dropdown-menu paper-input input') - .style('font-size', '14px'); + const inputs = + this.querySelectorAll('paper-dropdown-menu paper-input input'); + for (let i = 0; i < inputs.length; i++) { + (inputs[i] as HTMLElement).style.fontSize = '14px'; + } } restoreUIFromBookmark(bookmark: State) { @@ -226,9 +235,11 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.updateTSNEPerplexityFromSliderChange(); this.updateTSNELearningRateFromUIChange(); if (this.iterationLabel) { - this.iterationLabel.text(bookmark.tSNEIteration.toString()); + this.iterationLabel.innerText = bookmark.tSNEIteration.toString(); + } + if (bookmark.selectedProjection != null) { + this.showTab(bookmark.selectedProjection); } - this.showTab(bookmark.selectedProjection); this.enablePolymerChangesTriggerReprojection(); } @@ -282,7 +293,11 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { // and the DOM. setZDropdownEnabled(enabled: boolean) { if (this.zDropdown) { - this.zDropdown.attr('disabled', enabled ? null : true); + if (enabled) { + this.zDropdown.removeAttribute('disabled'); + } else { + this.zDropdown.setAttribute('disabled', 'true'); + } } } @@ -296,13 +311,13 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.updateTSNEPerplexityFromSliderChange(); this.clearCentroids(); - this.dom.select('#tsne-sampling') - .style('display', pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'); + (this.querySelector('#tsne-sampling') as HTMLElement).style.display = + pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'; const wasSampled = (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || dataSet.dim[1] > data.PCA_SAMPLE_DIM); - this.dom.select('#pca-sampling') - .style('display', wasSampled ? null : 'none'); + (this.querySelector('#pca-sampling') as HTMLElement).style.display = + wasSampled ? null : 'none'; this.showTab('pca'); } @@ -332,12 +347,24 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { public showTab(id: ProjectionType) { this.currentProjection = id; - let tab = this.dom.select('.ink-tab[data-tab="' + id + '"]'); - this.dom.selectAll('.ink-tab').classed('active', false); - tab.classed('active', true); - this.dom.selectAll('.ink-panel-content').classed('active', false); - this.dom.select('.ink-panel-content[data-panel="' + id + '"]') - .classed('active', true); + const tab = + this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement; + const allTabs = this.querySelectorAll('.ink-tab'); + for (let i = 0; i < allTabs.length; i++) { + util.classed(allTabs[i] as HTMLElement, 'active', false); + } + + util.classed(tab, 'active', true); + + const allTabContent = this.querySelectorAll('.ink-panel-content'); + for (let i = 0; i < allTabContent.length; i++) { + util.classed(allTabContent[i] as HTMLElement, 'active', false); + } + + util.classed( + this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as + HTMLElement, + 'active', true); // guard for unit tests, where polymer isn't attached and $ doesn't exist. if (this.$ != null) { @@ -392,17 +419,17 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { } private runTSNE() { - this.runTsneButton.attr('disabled', true); - this.stopTsneButton.attr('disabled', null); + this.runTsneButton.disabled = true; + this.stopTsneButton.disabled = null; this.dataSet.projectTSNE( this.perplexity, this.learningRate, this.tSNEis3d ? 
3 : 2, (iteration: number) => { if (iteration != null) { - this.iterationLabel.text(iteration); + this.iterationLabel.innerText = '' + iteration; this.projector.notifyProjectionPositionsUpdated(); } else { - this.runTsneButton.attr('disabled', null); - this.stopTsneButton.attr('disabled', true); + this.runTsneButton.disabled = null; + this.stopTsneButton.disabled = true; } }); } @@ -422,7 +449,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { totalVariance += variances[this.pcaZ]; } msg += (totalVariance * 100).toFixed(1) + '%.'; - this.dom.select('#total-variance').html(msg); + (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg; } private showPCA() { @@ -440,7 +467,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.projector.setProjection(projection); let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); this.updateTotalVarianceMessage(); - this.pcaComponents = d3.range(0, numComponents).map(i => { + this.pcaComponents = util.range(numComponents).map(i => { let fracVariance = this.dataSet.fracVariancesExplained[i]; return { id: i, diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel_test.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel_test.ts deleted file mode 100644 index 3ce35afb743..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel_test.ts +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
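The projections panel now leans on util.classed and util.range from the newly imported './util', whose implementation is not part of this diff. Judging only from the call sites, the helpers plausibly look like the sketch below; the signatures are inferred assumptions, not copies of the actual util.ts.

```ts
// Assumed shapes for the './util' helpers referenced above, inferred
// from their call sites in this diff rather than from the real util.ts.

// Mirrors d3's selection.classed(name, value) for a single element.
export function classed(
    el: HTMLElement, className: string, enabled: boolean) {
  if (enabled) {
    el.classList.add(className);
  } else {
    el.classList.remove(className);
  }
}

// Stand-in for d3.range(0, n): returns [0, 1, ..., n - 1].
export function range(count: number): number[] {
  const result: number[] = [];
  for (let i = 0; i < count; i++) {
    result.push(i);
  }
  return result;
}
```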
-==============================================================================*/ -import {State} from './data'; -import {ProjectionsPanel} from './vz-projector-projections-panel'; - -const assert = chai.assert; - -describe('restoreUIFromBookmark', () => { - it('sets the pcaX/Y properties when setting 2D component values', () => { - let projectionsPanel = document.createElement( - ProjectionsPanel.prototype.is) as ProjectionsPanel; - - spyOn(projectionsPanel, 'setZDropdownEnabled'); - - const s = new State(); - s.pcaComponentDimensions = [0, 1]; - projectionsPanel.restoreUIFromBookmark(s); - - assert.equal(0, projectionsPanel.pcaX); - assert.equal(1, projectionsPanel.pcaY); - - expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(false); - }); - - it('sets the pcaX/Y properties when setting 3D component values', () => { - let projectionsPanel = document.createElement( - ProjectionsPanel.prototype.is) as ProjectionsPanel; - - spyOn(projectionsPanel, 'setZDropdownEnabled'); - - const s = new State(); - s.pcaComponentDimensions = [0, 1, 2]; - projectionsPanel.restoreUIFromBookmark(s); - - assert.equal(0, projectionsPanel.pcaX); - assert.equal(1, projectionsPanel.pcaY); - assert.equal(2, projectionsPanel.pcaZ); - - expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(true); - }); -}); - -describe('populateBookmarkFromUI', () => { - it('gets the PCA component UI values from a 2D PCA projection', () => { - let projectionsPanel = document.createElement( - ProjectionsPanel.prototype.is) as ProjectionsPanel; - - projectionsPanel.pcaX = 0; - projectionsPanel.pcaY = 1; - projectionsPanel.pcaIs3d = false; - - const s = new State(); - projectionsPanel.populateBookmarkFromUI(s); - assert.deepEqual([0, 1], s.pcaComponentDimensions); - }); - - it('gets the PCA component UI values from a 3D PCA projection', () => { - let projectionsPanel = document.createElement( - ProjectionsPanel.prototype.is) as ProjectionsPanel; - - projectionsPanel.pcaX = 0; - projectionsPanel.pcaY = 1; - projectionsPanel.pcaZ = 2; - projectionsPanel.pcaIs3d = true; - - const s = new State(); - projectionsPanel.populateBookmarkFromUI(s); - assert.deepEqual([0, 1, 2], s.pcaComponentDimensions); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html index d4be2f26a5d..438ea9f4e97 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.html @@ -32,6 +32,7 @@ limitations under the License. + @@ -40,6 +41,7 @@ limitations under the License. 
+ + diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts index ba0f669e56f..bf98a4d4785 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts @@ -70,7 +70,6 @@ export class Projector extends ProjectorPolymer implements private originalDataSet: DataSet; private dataSetBeforeFilter: DataSet; - private dom: d3.Selection; private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; private dim: number; @@ -94,13 +93,12 @@ export class Projector extends ProjectorPolymer implements private projectionsPanel: ProjectionsPanel; private metadataCard: MetadataCard; - private statusBar: d3.Selection; + private statusBar: HTMLDivElement; private analyticsLogger: AnalyticsLogger; private eventLogging: boolean; private pageViewLogging: boolean; ready() { - this.dom = d3.select(this); logging.setDomContainer(this); this.analyticsLogger = @@ -130,7 +128,7 @@ export class Projector extends ProjectorPolymer implements this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel; this.bookmarkPanel.initialize(this, this as ProjectorEventContext); this.metadataCard = this.$['metadata-card'] as MetadataCard; - this.statusBar = this.dom.select('#status-bar'); + this.statusBar = this.querySelector('#status-bar') as HTMLDivElement; this.scopeSubtree(this.$$('#notification-dialog'), true); this.setupUIControls(); this.initializeDataProvider(); @@ -199,8 +197,8 @@ export class Projector extends ProjectorPolymer implements this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); // Set the container to a fixed height, otherwise in Colab the // height can grow indefinitely. - let container = this.dom.select('#container'); - container.style('height', container.property('clientHeight') + 'px'); + const container = this.querySelector('#container') as HTMLDivElement; + container.style.height = container.clientHeight + 'px'; } else { this.setCurrentDataSet(null); } @@ -226,7 +224,7 @@ export class Projector extends ProjectorPolymer implements this.dataSetFilterIndices = pointIndices; this.projectorScatterPlotAdapter.updateScatterPlotPositions(); this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.adjustSelectionAndHover(d3.range(selectionSize)); + this.adjustSelectionAndHover(util.range(selectionSize)); } resetFilterDataset() { @@ -387,8 +385,10 @@ export class Projector extends ProjectorPolymer implements ds.normalize(); } this.dim = (ds == null) ? 0 : ds.dim[1]; - this.dom.select('span.numDataPoints').text((ds == null) ? '0' : ds.dim[0]); - this.dom.select('span.dim').text((ds == null) ? '0' : ds.dim[1]); + (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText = + (ds == null) ? '0' : '' + ds.dim[0]; + (this.querySelector('span.dim') as HTMLSpanElement).innerText = + (ds == null) ? 
'0' : '' + ds.dim[1]; this.dataSet = ds; @@ -425,10 +425,9 @@ export class Projector extends ProjectorPolymer implements }); window.addEventListener('resize', () => { - let container = this.dom.select('#container'); - let parentHeight = - (container.node().parentNode as HTMLElement).clientHeight; - container.style('height', parentHeight + 'px'); + const container = this.querySelector('#container') as HTMLDivElement; + const parentHeight = (container.parentNode as HTMLElement).clientHeight; + container.style.height = parentHeight + 'px'; this.projectorScatterPlotAdapter.resize(); }); @@ -463,13 +462,13 @@ export class Projector extends ProjectorPolymer implements } } if (this.selectedPointIndices.length === 0) { - this.statusBar.style('display', hoverText ? null : 'none'); - this.statusBar.text(hoverText); + this.statusBar.style.display = hoverText ? null : 'none'; + this.statusBar.innerText = hoverText; } } - private getScatterContainer(): d3.Selection { - return this.dom.select('#scatter'); + private getScatterContainer(): HTMLDivElement { + return this.querySelector('#scatter') as HTMLDivElement; } private onSelectionChanged( @@ -479,8 +478,8 @@ export class Projector extends ProjectorPolymer implements this.neighborsOfFirstPoint = neighborsOfFirstPoint; let totalNumPoints = this.selectedPointIndices.length + neighborsOfFirstPoint.length; - this.statusBar.text(`Selected ${totalNumPoints} points`) - .style('display', totalNumPoints > 0 ? null : 'none'); + this.statusBar.innerText = `Selected ${totalNumPoints} points`; + this.statusBar.style.display = totalNumPoints > 0 ? null : 'none'; } setProjection(projection: Projection) { diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD index ae3f6e27774..e06b8ae1979 100644 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/BUILD @@ -1,25 +1,24 @@ -package(default_visibility = ["//tensorflow:internal"]) +package(default_visibility = ["//tensorflow/tensorboard:internal"]) -load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 -webfiles( +ts_web_library( name = "vz_sorting", srcs = [ + "sorting.ts", "vz-sorting.html", - ":ts", ], path = "/vz-sorting", visibility = ["//visibility:public"], ) -tensorboard_typescript_genrule( - name = "ts", - srcs = ["sorting.ts"], +tensorboard_webcomponent_library( + name = "legacy", + srcs = [":vz_sorting"], + destdir = "vz-sorting", ) filegroup( @@ -27,22 +26,3 @@ filegroup( srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [ - "vz-sorting.html", - ":legacy_ts", - ], - visibility = ["//visibility:public"], - destdir = "vz-sorting", -) - -tensorboard_ts_library( - name = "legacy_ts", - srcs = ["sorting.ts"], - deps = ["//tensorflow/tensorboard/components:common_deps"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/sorting.ts b/tensorflow/tensorboard/components/vz_sorting/sorting.ts index c1a656c34b8..061184d24bf 
100644 --- a/tensorflow/tensorboard/components/vz_sorting/sorting.ts +++ b/tensorflow/tensorboard/components/vz_sorting/sorting.ts @@ -13,95 +13,95 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -module VZ.Sorting { - /** - * Compares tag names asciinumerically broken into components. - * - *
<p>
This is the comparison function used for sorting most string values in - * TensorBoard. Unlike the standard asciibetical comparator, this function - * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are - * supported. This function also splits the input by slash and underscore to - * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even - * though '+' < '/' in the ASCII table. - */ - export function compareTagNames(a, b: string): number { - let ai = 0; - let bi = 0; - while (true) { - if (ai === a.length) { - return bi === b.length ? 0 : -1; - } - if (bi === b.length) { - return 1; - } - if (isDigit(a[ai]) && isDigit(b[bi])) { - const ais = ai; - const bis = bi; - ai = consumeNumber(a, ai + 1); - bi = consumeNumber(b, bi + 1); - const an = parseFloat(a.slice(ais, ai)); - const bn = parseFloat(b.slice(bis, bi)); - if (an < bn) { - return -1; - } - if (an > bn) { - return 1; - } - continue; - } - if (isBreak(a[ai])) { - if (!isBreak(b[bi])) { - return -1; - } - } else if (isBreak(b[bi])) { - return 1; - } else if (a[ai] < b[bi]) { +/** + * Compares tag names asciinumerically broken into components. + * + *
<p>
This is the comparison function used for sorting most string values in + * TensorBoard. Unlike the standard asciibetical comparator, this function + * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are + * supported. This function also splits the input by slash and underscore to + * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even + * though '+' < '/' in the ASCII table. + */ +export function compareTagNames(a, b: string): number { + let ai = 0; + let bi = 0; + while (true) { + if (ai === a.length) { + return bi === b.length ? 0 : -1; + } + if (bi === b.length) { + return 1; + } + if (isDigit(a[ai]) && isDigit(b[bi])) { + const ais = ai; + const bis = bi; + ai = consumeNumber(a, ai + 1); + bi = consumeNumber(b, bi + 1); + const an = parseFloat(a.slice(ais, ai)); + const bn = parseFloat(b.slice(bis, bi)); + if (an < bn) { return -1; - } else if (a[ai] > b[bi]) { + } + if (an > bn) { return 1; } - ai++; - bi++; + continue; } - } - - function consumeNumber(s: string, i: number): number { - enum State { NATURAL, REAL, EXPONENT_SIGN, EXPONENT } - let state = State.NATURAL; - for (; i < s.length; i++) { - if (state === State.NATURAL) { - if (s[i] === '.') { - state = State.REAL; - } else if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.REAL) { - if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.EXPONENT_SIGN) { - if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { - state = State.EXPONENT; - } else { - break; - } - } else if (state === State.EXPONENT) { - if (!isDigit(s[i])) { - break; - } + if (isBreak(a[ai])) { + if (!isBreak(b[bi])) { + return -1; } + } else if (isBreak(b[bi])) { + return 1; + } else if (a[ai] < b[bi]) { + return -1; + } else if (a[ai] > b[bi]) { + return 1; } - return i; - } - - function isDigit(c: string): boolean { return '0' <= c && c <= '9'; } - - function isBreak(c: string): boolean { - // TODO(jart): Remove underscore when people stop using it like a slash. - return c === '/' || c === '_' || isDigit(c); + ai++; + bi++; } } + +function consumeNumber(s: string, i: number): number { + enum State { NATURAL, REAL, EXPONENT_SIGN, EXPONENT } + let state = State.NATURAL; + for (; i < s.length; i++) { + if (state === State.NATURAL) { + if (s[i] === '.') { + state = State.REAL; + } else if (s[i] === 'e' || s[i] === 'E') { + state = State.EXPONENT_SIGN; + } else if (!isDigit(s[i])) { + break; + } + } else if (state === State.REAL) { + if (s[i] === 'e' || s[i] === 'E') { + state = State.EXPONENT_SIGN; + } else if (!isDigit(s[i])) { + break; + } + } else if (state === State.EXPONENT_SIGN) { + if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { + state = State.EXPONENT; + } else { + break; + } + } else if (state === State.EXPONENT) { + if (!isDigit(s[i])) { + break; + } + } + } + return i; +} + +function isDigit(c: string): boolean { + return '0' <= c && c <= '9'; +} + +function isBreak(c: string): boolean { + // TODO(jart): Remove underscore when people stop using it like a slash. 
+ return c === '/' || c === '_' || isDigit(c); +} diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD index f8b01b61f29..929e80d3728 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ b/tensorflow/tensorboard/components/vz_sorting/test/BUILD @@ -1,35 +1,37 @@ -package(default_visibility = ["//tensorflow:internal"]) +package( + default_testonly = True, + default_visibility = ["//tensorflow/tensorboard:internal"], +) -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") -load("//tensorflow/tensorboard:defs.bzl", "tensorboard_wct_test_suite") +load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") +load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") licenses(["notice"]) # Apache 2.0 +ts_web_library( + name = "test", + srcs = [ + "sortingTests.ts", + "tests.html", + ], + path = "/vz-sorting/test", + deps = [ + "//tensorflow/tensorboard/components/tf_imports:web_component_tester", + "//tensorflow/tensorboard/components/vz_sorting", + ], +) + +tensorboard_html_binary( + name = "devserver", + compilation_level = "WHITESPACE_ONLY", + input_path = "/vz-sorting/test/tests.html", + output_path = "/vz-sorting/test/tests.html", + deps = [":test"], +) + filegroup( name = "all_files", + testonly = 0, srcs = glob(["**"]), tags = ["notsan"], ) - -################################################################################ -# MARKED FOR DELETION - -tensorboard_wct_test_suite( - name = "legacy_test", - size = "medium", - srcs = ["index.html"], - deps = [ - "//tensorflow/tensorboard/components/vz_sorting:legacy", - "//third_party/javascript/polymer/v1/webcomponentsjs:lib", - ], -) - -tensorboard_ts_library( - name = "legacy_ts", - testonly = 1, - srcs = ["sortingTests.ts"], - deps = [ - "//tensorflow/tensorboard/components:common_deps", - "//tensorflow/tensorboard/components/vz_sorting:legacy_ts", - ], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/index.html b/tensorflow/tensorboard/components/vz_sorting/test/index.html deleted file mode 100644 index 7148bfb4181..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/index.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts b/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts index 4dba3e35b9b..510685cb4b5 100644 --- a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts +++ b/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts @@ -13,69 +13,65 @@ See the License for the specific language governing permissions and limitations under the License. 
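With sorting.ts unwrapped from the VZ.Sorting namespace into plain ES6 module exports, consumers import the comparator directly instead of reading a global, which is exactly what the rewritten tests below do. A consumer-side sketch with made-up tag names:

```ts
// Before: a global namespace populated by script concatenation.
// const cmp = VZ.Sorting.compareTagNames;

// After: an explicit module import, resolvable by ts_web_library.
import {compareTagNames} from './sorting';

const tags = ['run_2/loss', 'run_10/loss', 'run_1/loss'];
tags.sort(compareTagNames);
// Numeric-aware result: ['run_1/loss', 'run_2/loss', 'run_10/loss'];
// a plain lexicographic sort would put 'run_10' before 'run_2'.
```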
==============================================================================*/ -module VZ.Sorting { +import {compareTagNames} from '../sorting'; + +describe('compareTagNames', () => { const assert = chai.assert; + const sortTagNames = (a) => a.sort(compareTagNames); - describe('compareTagNames', () => { - - const sortTagNames = (a) => a.sort(compareTagNames); - - it('is asciibetical', () => { - assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']); - assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']); - }); - - it('sorts integer portions', () => { - assert.deepEqual(['03', '1'].sort(), ['03', '1']); - assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']); - assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']); - assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']); - assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']); - assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']); - }); - - it('sorts fixed point numbers', () => { - assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']); - }); - - it('sorts engineering notation', () => { - assert.deepEqual(sortTagNames(['a1e9', 'a9e8']), ['a9e8', 'a1e9']); - assert.deepEqual(sortTagNames(['a1e+9', 'a9e+8']), ['a9e+8', 'a1e+9']); - assert.deepEqual(sortTagNames(['a1e+5', 'a9e-6']), ['a9e-6', 'a1e+5']); - assert.deepEqual( - sortTagNames(['a1.0e9', 'a9.0e8']), ['a9.0e8', 'a1.0e9']); - assert.deepEqual( - sortTagNames(['a1.0e+9', 'a9.0e+8']), ['a9.0e+8', 'a1.0e+9']); - }); - - it('is componentized by slash', () => { - assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 'ab/a']); - }); - - it('is componentized by underscore', () => { - assert.deepEqual( - sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']); - }); - - it('is componentized by number boundaries', () => { - assert.deepEqual( - sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']); - }); - - it('empty comes first', () => { - assert.deepEqual( - sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']); - }); - - it('decimal parsed correctly', () => { - assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']); - assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']); - assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']); - }); + it('is asciibetical', () => { + assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']); + assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']); }); -} + + it('sorts integer portions', () => { + assert.deepEqual(['03', '1'].sort(), ['03', '1']); + assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']); + assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']); + assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']); + assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']); + assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']); + }); + + it('sorts fixed point numbers', () => { + assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']); + }); + + it('sorts engineering notation', () => { + assert.deepEqual(sortTagNames(['a1e9', 'a9e8']), ['a9e8', 'a1e9']); + assert.deepEqual(sortTagNames(['a1e+9', 'a9e+8']), ['a9e+8', 'a1e+9']); + assert.deepEqual(sortTagNames(['a1e+5', 'a9e-6']), ['a9e-6', 'a1e+5']); + assert.deepEqual(sortTagNames(['a1.0e9', 'a9.0e8']), ['a9.0e8', 'a1.0e9']); + assert.deepEqual( 
+ sortTagNames(['a1.0e+9', 'a9.0e+8']), ['a9.0e+8', 'a1.0e+9']); + }); + + it('is componentized by slash', () => { + assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']); + assert.deepEqual( + sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 'ab/a']); + }); + + it('is componentized by underscore', () => { + assert.deepEqual( + sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']); + assert.deepEqual( + sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']); + }); + + it('is componentized by number boundaries', () => { + assert.deepEqual( + sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']); + }); + + it('empty comes first', () => { + assert.deepEqual(sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']); + }); + + it('decimal parsed correctly', () => { + assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']); + assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']); + assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']); + }); +}); diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html new file mode 100644 index 00000000000..f92c608cdb1 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_sorting/test/tests.html @@ -0,0 +1,23 @@ + + + + + + + + diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl deleted file mode 100644 index 3488978ab2d..00000000000 --- a/tensorflow/tensorboard/defs.bzl +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_DEFAULT_TYPINGS = [ - "@com_microsoft_typescript//:lib.es6.d.ts", -] - -def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs): - """Filegroup of compiled TypeScript sources. - - This is a very unsophisticated TypeScript rule where the user is responsible - for passing all typings and sources via srcs. It's meant as a stopgap because - TypeScript rules currently don't exist for Bazel. The definition of this rule - will need to evolve as more ts_library rules are migrated. 
- """ - for src in srcs: - if (src.startswith("/") or - src.endswith(".d.ts") or - not src.endswith(".ts")): - fail("srcs must be typescript sources in same package") - native.genrule( - name = name, - srcs = _DEFAULT_TYPINGS + typings + srcs, - outs = [src[:-3] + ".js" for src in srcs], - cmd = "$(location @com_microsoft_typescript//:tsc)" + - " --inlineSourceMap" + - " --inlineSources" + - " --outDir $(@D)" + - " $(SRCS)", - tools = ["@com_microsoft_typescript//:tsc"], - **kwargs - ) - -def tensorboard_ts_library(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_webcomponent_library(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_wct_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass diff --git a/tensorflow/tensorboard/defs/BUILD b/tensorflow/tensorboard/defs/BUILD new file mode 100644 index 00000000000..92a2af34048 --- /dev/null +++ b/tensorflow/tensorboard/defs/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +licenses(["notice"])  # Apache 2.0 + +filegroup( + name = "ts_web_library_default_typings", + srcs = [ + # Ordering probably matters. + "@com_microsoft_typescript//:lib.es6.d.ts", + "@io_angular_clutz//:src/resources/closure.lib.d.ts", + "clutz.d.ts", + ], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/tensorboard/defs/clutz.d.ts b/tensorflow/tensorboard/defs/clutz.d.ts new file mode 100644 index 00000000000..47cf307d261 --- /dev/null +++ b/tensorflow/tensorboard/defs/clutz.d.ts @@ -0,0 +1,19 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// tslint:disable +declare namespace ಠ_ಠ.clutz { + interface IteratorIterable<T> extends Iterator<T>, Iterable<T> {} + interface IIterableResult<T> extends IteratorResult<T> {} +} diff --git a/tensorflow/tensorboard/defs/defs.bzl b/tensorflow/tensorboard/defs/defs.bzl new file mode 100644 index 00000000000..94e2d7c540f --- /dev/null +++ b/tensorflow/tensorboard/defs/defs.bzl @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
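The clutz.d.ts shim above declares, inside the ಠ_ಠ.clutz namespace, iteration interfaces that Clutz-generated typings for Closure libraries can refer to; they mirror shapes TypeScript already ships in lib.es6.d.ts, which the filegroup lists first since ordering probably matters. An interpretive sketch, not taken from the diff, of the built-in counterpart at work:

```ts
// Interpretive sketch: TypeScript's built-in IterableIterator<T> plays
// roughly the role ಠ_ಠ.clutz.IteratorIterable<T> plays for
// Clutz-translated Closure code; a generator satisfies it directly.
function* naturals(limit: number): IterableIterator<number> {
  for (let i = 0; i < limit; i++) {
    yield i;
  }
}
```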
+ +def tensorboard_webcomponent_library(**kwargs): + """Rules referencing this will be deleted from the codebase soon.""" + pass + +def _legacy_js_impl(target, ctx): + return struct() + +legacy_js = aspect( + implementation=_legacy_js_impl, + attr_aspects=["exports"]) diff --git a/tensorflow/tensorboard/defs/hacks.bzl b/tensorflow/tensorboard/defs/hacks.bzl new file mode 100644 index 00000000000..f1d4be79061 --- /dev/null +++ b/tensorflow/tensorboard/defs/hacks.bzl @@ -0,0 +1,80 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(jart): Merge this file into defs.bzl once that file is sync unified. + +def tensorboard_typescript_bundle( + name, + out, + namespace_srcs, + namespace_symbol_aliases={}, + namespace_symbol_aliases_public={}, + **kwargs): + """Rolls TypeScript ES6 modules into one vanilla source file without imports. + + This is a genrule wrapper that concatenates TypeScripts sources inside + namespace blocks while removing ^import lines. Because the sources themselves + are not parsed, the structure of the modules must be passed to this macro as + a Skylark data structure. + + Args: + name: Name of this build rule target. + out: Path of outputted TypeScript source file. + namespace_srcs: Multimap of namespace strings to build file targets. The + ordering of the dictionary and nested lists does not matter when + generating a typings file, but *does* matter when generating a source + file. + namespace_symbol_aliases: Map of namespace strings where each value is a + map of symbol names to fully qualified symbol names. + namespace_symbol_aliases_public: Same as namespace_symbol_aliases but the + symbol will be visible to other namespaces. + """ + cmd = ["(", "echo // GENERATED BY TENSORBOARD_TYPESCRIPT_BUNDLE"] + inputs = set() + for namespace, srcs in namespace_srcs.items(): + cmd.append("echo") + if out[-5:] == ".d.ts": + cmd.append("echo 'declare namespace %s {'" % namespace) + elif out[-3:] == ".ts": + cmd.append("echo 'module %s {'" % namespace) + else: + fail("'out' must end with .ts or .d.ts: " + out) + for symbol, canon in namespace_symbol_aliases.get(namespace, {}).items(): + cmd.append("echo 'import %s = %s;'" % (symbol, canon)) + for symbol, canon in namespace_symbol_aliases_public.get(namespace, + {}).items(): + cmd.append("echo 'export import %s = %s;'" % (symbol, canon)) + inputs += srcs + for src in srcs: + cmd.append("for f in $(locations %s); do" % src) + cmd.append(" echo") + cmd.append(" echo /////////////////////////////////////////////////////") + cmd.append(" echo // " + namespace) + cmd.append(" echo // $$f") + cmd.append(" echo /////////////////////////////////////////////////////") + cmd.append(" echo") + cmd.append(" sed 's!^import !// import !' $$f \\") + cmd.append(" | sed 's!^export declare !export !' 
\\") + cmd.append(" | sed '/^export .* from /d' \\") + cmd.append(" | sed '/^export {.*};$$/d'") + cmd.append("done") + cmd.append("echo '}'") + cmd.append(") >$@") + native.genrule( + name = name, + srcs = list(inputs), + outs = [out], + cmd = "\n".join(cmd), + **kwargs + ) diff --git a/tensorflow/tensorboard/defs/protos.bzl b/tensorflow/tensorboard/defs/protos.bzl new file mode 100644 index 00000000000..6d1982e098d --- /dev/null +++ b/tensorflow/tensorboard/defs/protos.bzl @@ -0,0 +1,27 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@protobuf//:protobuf.bzl", "py_proto_library") + +def tb_proto_library(name, srcs = [], visibility = []): + py_proto_library( + name = name + "_py", + srcs = srcs, + srcs_version = "PY2AND3", + deps = ["@protobuf//:protobuf_python"], + protoc = "@protobuf//:protoc", + visibility = visibility, + default_runtime = "@protobuf//:protobuf_python", + testonly = 0, + ) \ No newline at end of file diff --git a/tensorflow/tensorboard/defs/vulcanize.bzl b/tensorflow/tensorboard/defs/vulcanize.bzl new file mode 100644 index 00000000000..6ff49a35ed7 --- /dev/null +++ b/tensorflow/tensorboard/defs/vulcanize.bzl @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") +load("@io_bazel_rules_closure//closure/private:defs.bzl", "collect_js", "unfurl", "long_path") +load("//tensorflow/tensorboard/defs:web.bzl", "web_aspect") + +def _tensorboard_html_binary(ctx): + deps = unfurl(ctx.attr.deps, provider="webfiles") + manifests = set(order="topological") + files = set() + webpaths = set() + for dep in deps: + manifests += dep.webfiles.manifests + webpaths += dep.webfiles.webpaths + files += dep.data_runfiles.files + webpaths += [ctx.attr.output_path] + closure_js_library=collect_js( + ctx, unfurl(ctx.attr.deps, provider="closure_js_library")) + + # vulcanize + jslibs = depset(ctx.files._jslibs) + closure_js_library.srcs + ctx.action( + inputs=list(manifests | files | jslibs), + outputs=[ctx.outputs.html], + executable=ctx.executable._Vulcanize, + arguments=([ctx.attr.compilation_level, + "true" if ctx.attr.testonly else "false", + ctx.attr.input_path, + ctx.attr.output_path, + ctx.outputs.html.path] + + [f.path for f in jslibs] + + [f.path for f in manifests]), + progress_message="Vulcanizing %s" % ctx.attr.input_path) + + # webfiles manifest + manifest_srcs = [struct(path=ctx.outputs.html.path, + longpath=long_path(ctx, ctx.outputs.html), + webpath=ctx.attr.output_path)] + manifest = ctx.new_file(ctx.configuration.bin_dir, + "%s.pbtxt" % ctx.label.name) + ctx.file_action( + output=manifest, + content=struct( + label=str(ctx.label), + src=manifest_srcs).to_proto()) + manifests += [manifest] + + # webfiles server + params = struct( + label=str(ctx.label), + bind="[::]:6006", + manifest=[long_path(ctx, man) for man in manifests], + external_asset=[struct(webpath=k, path=v) + for k, v in ctx.attr.external_assets.items()]) + params_file = ctx.new_file(ctx.configuration.bin_dir, + "%s_server_params.pbtxt" % ctx.label.name) + ctx.file_action(output=params_file, content=params.to_proto()) + ctx.file_action( + executable=True, + output=ctx.outputs.executable, + content="#!/bin/sh\nexec %s %s" % ( + ctx.executable._WebfilesServer.short_path, + long_path(ctx, params_file))) + + transitive_runfiles = depset() + transitive_runfiles += ctx.attr._WebfilesServer.data_runfiles.files + for dep in deps: + transitive_runfiles += dep.data_runfiles.files + return struct( + files=depset([ctx.outputs.html]), + webfiles=struct( + manifest=manifest, + manifests=manifests, + webpaths=webpaths, + dummy=ctx.outputs.html), + runfiles=ctx.runfiles( + files=ctx.files.data + [manifest, + params_file, + ctx.outputs.html, + ctx.outputs.executable], + transitive_files=transitive_runfiles)) + +tensorboard_html_binary = rule( + implementation=_tensorboard_html_binary, + executable=True, + attrs={ + "compilation_level": attr.string(default="ADVANCED"), + "input_path": attr.string(mandatory=True), + "output_path": attr.string(mandatory=True), + "data": attr.label_list(cfg="data", allow_files=True), + "deps": attr.label_list( + aspects=[ + web_aspect, + legacy_js, + ], + mandatory=True), + "external_assets": attr.string_dict(default={"/_/runfiles": "."}), + "_jslibs": attr.label( + default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:jslibs"), + allow_files=True), + "_Vulcanize": attr.label( + default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Vulcanize"), + executable=True, + cfg="host"), + "_WebfilesServer": attr.label( + default=Label( + "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), + executable=True, + cfg="host"), + }, + outputs={ 
+ "html": "%{name}.html", + }) diff --git a/tensorflow/tensorboard/defs/web.bzl b/tensorflow/tensorboard/defs/web.bzl new file mode 100644 index 00000000000..103942b0a25 --- /dev/null +++ b/tensorflow/tensorboard/defs/web.bzl @@ -0,0 +1,419 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Same as web_library but supports TypeScript.""" + +load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") + +load("//third_party:clutz.bzl", + "CLUTZ_ATTRIBUTES", + "CLUTZ_OUTPUTS", + "clutz_aspect", + "extract_dts_from_closure_libraries") + +load("@io_bazel_rules_closure//closure/private:defs.bzl", + "CLOSURE_LIBRARY_BASE_ATTR", + "CLOSURE_LIBRARY_DEPS_ATTR", + "collect_js", + "collect_runfiles", + "convert_path_to_es6_module_name", + "create_argfile", + "difference", + "long_path", + "unfurl") + +_ASPECT_SLURP_FILE_TYPE = FileType([ + ".html", ".js", ".css", ".gss", ".png", ".jpg", ".gif", ".ico", ".svg"]) + +_CLOSURE_WORKER = attr.label( + default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure:ClosureWorker"), + executable=True, + cfg="host") + +def _ts_web_library(ctx): + if not ctx.attr.srcs: + if ctx.attr.deps: + fail("deps can not be set when srcs is not") + if not ctx.attr.exports: + fail("exports must be set if srcs is not") + if ctx.attr.path: + if not ctx.attr.path.startswith("/"): + fail("webpath must start with /") + if ctx.attr.path != "/" and ctx.attr.path.endswith("/"): + fail("webpath must not end with / unless it is /") + if "//" in ctx.attr.path: + fail("webpath must not have //") + elif ctx.attr.srcs: + fail("path must be set when srcs is set") + if "*" in ctx.attr.suppress and len(ctx.attr.suppress) != 1: + fail("when \"*\" is suppressed no other items should be present") + + # process what came before + deps = unfurl(ctx.attr.deps, provider="webfiles") + webpaths = depset() + ts_typings = depset(ctx.files._default_typings) + ts_typings_paths = depset( + [long_path(ctx, f) for f in ctx.files._default_typings]) + ts_typings_execroots = depset() + aspect_runfiles = depset() + for dep in deps: + webpaths += dep.webfiles.webpaths + if hasattr(dep.webfiles, "ts_typings"): + ts_typings += dep.webfiles.ts_typings + if hasattr(dep.webfiles, "ts_typings_paths"): + ts_typings_paths += dep.webfiles.ts_typings_paths + if hasattr(dep.webfiles, "ts_typings_execroots"): + ts_typings_execroots += dep.webfiles.ts_typings_execroots + if hasattr(dep.webfiles, "aspect_runfiles"): + aspect_runfiles += dep.webfiles.aspect_runfiles + + # process what comes now + manifest_srcs = [] + new_webpaths = [] + ts_inputs = depset() + ts_outputs = [] + ts_files = list(ts_typings_paths) + new_typings = [] + new_typings_paths = [] + new_typings_execroot = struct(inputs=[]) + execroot = struct( + inputs=[(long_path(ctx, f), f.path) for f in ctx.files._default_typings], + outputs=[], + program=[ctx.executable._tsc.path, "-p"]) + web_srcs = [] + path = ctx.attr.path + strip = _get_strip(ctx) + for src in ctx.files.srcs: + suffix = 
_get_path_relative_to_package(src) + if strip: + if not suffix.startswith(strip): + fail("Relative src path not start with '%s': %s" % (strip, suffix)) + suffix = suffix[len(strip):] + webpath = "%s/%s" % ("" if path == "/" else path, suffix) + _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) + if suffix.endswith(".d.ts"): + web_srcs.append(src) + entry = (webpath[1:], src.path) + new_typings.append(src) + new_typings_paths.append(entry[0]) + new_typings_execroot.inputs.append(entry) + ts_inputs += [src] + ts_files.append(entry[0]) + execroot.inputs.append(entry) + elif suffix.endswith(".ts"): + noext = suffix[:-3] + js = ctx.new_file(ctx.bin_dir, "%s.js" % noext) + dts = ctx.new_file(ctx.bin_dir, "%s.d.ts" % noext) + webpath_js = webpath[:-3] + ".js" + webpath_dts = webpath[:-3] + ".d.ts" + _add_webpath(ctx, js, webpath_js, webpaths, new_webpaths, manifest_srcs) + _add_webpath(ctx, dts, webpath_dts, webpaths, new_webpaths, manifest_srcs) + ts_inputs += [src] + ts_outputs.append(js) + ts_outputs.append(dts) + web_srcs.append(dts) + web_srcs.append(js) + ts_files.append(webpath[1:]) + execroot.inputs.append((webpath[1:], src.path)) + execroot.outputs.append((webpath_js[1:], js.path)) + execroot.outputs.append((webpath_dts[1:], dts.path)) + new_typings.append(dts) + new_typings_paths.append(webpath_dts[1:]) + new_typings_execroot.inputs.append((webpath_dts[1:], dts.path)) + else: + web_srcs.append(src) + + # get typings for closure code + clutz_dts = extract_dts_from_closure_libraries(ctx) + if clutz_dts: + entry = (long_path(ctx, clutz_dts), clutz_dts.path) + ts_inputs += [clutz_dts] + ts_files.append(entry[0]) + execroot.inputs.append(entry) + + # compile typescript + workspace = "" + if ctx.label.workspace_root: + workspace = "/" + ctx.label.workspace_root + if execroot.outputs: + ts_config = _new_file(ctx, "-tsc.json") + execroot.inputs.append(("tsconfig.json", ts_config.path)) + ctx.file_action( + output=ts_config, + content=struct( + compilerOptions=struct( + baseUrl=".", + declaration=True, + inlineSourceMap=True, + inlineSources=True, + module="es6", + moduleResolution="node", + noResolve=True, + target="es5", + ), + files=ts_files, + ).to_json()) + er_config = _new_file(ctx, "-tsc-execroot.json") + ctx.file_action(output=er_config, content=execroot.to_json()) + ts_inputs += collect_runfiles([ctx.attr._tsc]) + ts_inputs += ctx.files._tsc + ts_inputs += ts_typings + ts_inputs += ts_typings_execroots + ts_inputs += [ts_config, er_config] + ctx.action( + inputs=list(ts_inputs), + outputs=ts_outputs, + executable=ctx.executable._execrooter, + arguments=[er_config.path] + [f.path for f in ts_typings_execroots], + progress_message="Compiling %d TypeScript files %s" % ( + len(ts_files), ctx.label)) + + # perform strict dependency checking + manifest = _make_manifest(ctx, manifest_srcs) + webpaths += new_webpaths + dummy, manifests = _run_webfiles_validator(ctx, web_srcs, deps, manifest) + web_srcs.append(dummy) + + # define development web server that only applies to this transitive closure + params = struct( + label=str(ctx.label), + bind="[::]:6006", + manifest=[long_path(ctx, man) for man in manifests], + external_asset=[struct(webpath=k, path=v) + for k, v in ctx.attr.external_assets.items()]) + params_file = _new_file(ctx, "-params.pbtxt") + ctx.file_action(output=params_file, content=params.to_proto()) + ctx.file_action( + executable=True, + output=ctx.outputs.executable, + content="#!/bin/sh\nexec %s %s" % ( + ctx.executable._WebfilesServer.short_path, + 
long_path(ctx, params_file))) + + if new_typings: + er_config = _new_file(ctx, "-typings-execroot.json") + ctx.file_action(output=er_config, content=new_typings_execroot.to_json()) + ts_typings += new_typings + ts_typings_paths += new_typings_paths + ts_typings_execroots += [er_config] + else: + ts_typings = depset() + ts_typings_paths = depset() + ts_typings_execroots = depset() + + # export data to parent rules + return struct( + files=depset(web_srcs + [dummy]), + exports=unfurl(ctx.attr.exports), + webfiles=struct( + manifest=manifest, + manifests=manifests, + webpaths=webpaths, + dummy=dummy, + ts_typings=ts_typings, + ts_typings_paths=ts_typings_paths, + ts_typings_execroots=ts_typings_execroots), + closure_js_library=collect_js( + ctx, unfurl(ctx.attr.deps, provider="closure_js_library")), + runfiles=ctx.runfiles( + files=ctx.files.srcs + ctx.files.data + ts_outputs + [ + manifest, + params_file, + ctx.outputs.executable, + dummy], + transitive_files=(collect_runfiles([ctx.attr._WebfilesServer]) | + collect_runfiles(deps) | + collect_runfiles(ctx.attr.data) | + aspect_runfiles))) + +def _web_aspect_impl(target, ctx): + if hasattr(target, "webfiles"): + return struct() + srcs = [] + deps = [] + if hasattr(ctx.rule.files, "srcs"): + srcs.extend(_ASPECT_SLURP_FILE_TYPE.filter(ctx.rule.files.srcs)) + for attr in ("deps", "sticky_deps", "module_deps"): + value = getattr(ctx.rule.attr, attr, None) + if value: + deps.extend(value) + deps = unfurl(deps, provider="webfiles") + webpaths = depset() + aspect_runfiles = depset(srcs) + for dep in deps: + webpaths += dep.webfiles.webpaths + if hasattr(dep.webfiles, "aspect_runfiles"): + aspect_runfiles += dep.webfiles.aspect_runfiles + manifest_srcs = [] + new_webpaths = [] + for src in srcs: + webpath = "/" + long_path(ctx, src) + _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) + webpaths += new_webpaths + manifest = _make_manifest(ctx, manifest_srcs) + dummy, manifests = _run_webfiles_validator(ctx, srcs, deps, manifest) + aspect_runfiles += [dummy, manifest] + return struct( + webfiles=struct( + manifest=manifest, + manifests=manifests, + webpaths=webpaths, + dummy=dummy, + aspect_runfiles=aspect_runfiles)) + +def _make_manifest(ctx, src_list): + manifest = _new_file(ctx, "-webfiles.pbtxt") + ctx.file_action( + output=manifest, + content=struct( + label=str(ctx.label), + src=src_list).to_proto()) + return manifest + +def _run_webfiles_validator(ctx, srcs, deps, manifest): + dummy = _new_file(ctx, "-webfiles.ignoreme") + manifests = depset(order="topological") + for dep in deps: + manifests += dep.webfiles.manifests + if srcs: + args = ["WebfilesValidator", + "--dummy", dummy.path, + "--target", manifest.path] + if hasattr(ctx, "attr") and hasattr(ctx.attr, "suppress"): + for category in ctx.attr.suppress: + args.append("--suppress") + args.append(category) + inputs = [manifest] + inputs.extend(srcs) + direct_manifests = depset() + for dep in deps: + inputs.append(dep.webfiles.dummy) + for f in dep.files: + inputs.append(f) + direct_manifests += [dep.webfiles.manifest] + inputs.append(dep.webfiles.manifest) + args.append("--direct_dep") + args.append(dep.webfiles.manifest.path) + for man in difference(manifests, direct_manifests): + inputs.append(man) + args.append("--transitive_dep") + args.append(man.path) + argfile = _new_file(ctx, "-webfiles-checker-args.txt") + ctx.file_action(output=argfile, content="\n".join(args)) + inputs.append(argfile) + ctx.action( + inputs=inputs, + outputs=[dummy], + 
executable=(getattr(ctx.executable, "_ClosureWorker", None) or + getattr(ctx.executable, "_ClosureWorkerAspect", None)), + arguments=["@@" + argfile.path], + mnemonic="Closure", + execution_requirements={"supports-workers": "1"}, + progress_message="Checking webfiles %s" % ctx.label) + else: + ctx.file_action(output=dummy, content="BOO!") + manifests += [manifest] + return dummy, manifests + +def _new_file(ctx, suffix): + return ctx.new_file(ctx.bin_dir, "%s%s" % (ctx.label.name, suffix)) + +def _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs): + if webpath in new_webpaths: + _fail(ctx, "multiple srcs within %s define the webpath %s " % ( + ctx.label, webpath)) + if webpath in webpaths: + _fail(ctx, "webpath %s was defined by %s when already defined by deps" % ( + webpath, ctx.label)) + new_webpaths.append(webpath) + manifest_srcs.append(struct( + path=src.path, + longpath=long_path(ctx, src), + webpath=webpath)) + +def _fail(ctx, message): + if ctx.attr.suppress == ["*"]: + print(message) + else: + fail(message) + +def _get_path_relative_to_package(artifact): + """Returns file path relative to the package that declared it.""" + path = artifact.path + for prefix in (artifact.root.path, + artifact.owner.workspace_root if artifact.owner else '', + artifact.owner.package if artifact.owner else ''): + if prefix: + prefix = prefix + "/" + if not path.startswith(prefix): + fail("Path %s doesn't start with %s" % (path, prefix)) + path = path[len(prefix):] + return path + +def _get_strip(ctx): + strip = ctx.attr.strip_prefix + if strip: + if strip.startswith("/"): + _fail(ctx, "strip_prefix should not start with /") + strip = strip[1:] + if strip.endswith("/"): + _fail(ctx, "strip_prefix should not end with /") + else: + strip += "/" + return strip + +web_aspect = aspect( + implementation=_web_aspect_impl, + attr_aspects=["deps", "sticky_deps", "module_deps"], + attrs={"_ClosureWorkerAspect": _CLOSURE_WORKER}) + +ts_web_library = rule( + implementation=_ts_web_library, + executable=True, + attrs=CLUTZ_ATTRIBUTES + { + "path": attr.string(), + "srcs": attr.label_list(allow_files=True), + "deps": attr.label_list( + aspects=[ + web_aspect, + clutz_aspect, + legacy_js, + ]), + "exports": attr.label_list(), + "data": attr.label_list(cfg="data", allow_files=True), + "suppress": attr.string_list(), + "strip_prefix": attr.string(), + "external_assets": attr.string_dict(default={"/_/runfiles": "."}), + "clutz_entry_points": attr.string_list(), + "_execrooter": attr.label( + default=Label("//tensorflow/tensorboard/scripts:execrooter"), + executable=True, + cfg="host"), + "_tsc": attr.label( + default=Label("@com_microsoft_typescript//:tsc"), + allow_files=True, + executable=True, + cfg="host"), + "_default_typings": attr.label( + default=Label("//tensorflow/tensorboard:ts_web_library_default_typings"), + allow_files=True), + "_WebfilesServer": attr.label( + default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), + executable=True, + cfg="host"), + "_ClosureWorker": _CLOSURE_WORKER, + "_closure_library_base": CLOSURE_LIBRARY_BASE_ATTR, + "_closure_library_deps": CLOSURE_LIBRARY_DEPS_ATTR, + }, + outputs=CLUTZ_OUTPUTS) diff --git a/tensorflow/tensorboard/defs/zipper.bzl b/tensorflow/tensorboard/defs/zipper.bzl new file mode 100644 index 00000000000..e98309ec9a5 --- /dev/null +++ b/tensorflow/tensorboard/defs/zipper.bzl @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
diff --git a/tensorflow/tensorboard/defs/zipper.bzl b/tensorflow/tensorboard/defs/zipper.bzl
new file mode 100644
index 00000000000..e98309ec9a5
--- /dev/null
+++ b/tensorflow/tensorboard/defs/zipper.bzl
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@io_bazel_rules_closure//closure/private:defs.bzl", "unfurl", "long_path")
+
+def _tensorboard_zip_file(ctx):
+  deps = unfurl(ctx.attr.deps, provider="webfiles")
+  manifests = set(order="link")
+  files = set()
+  webpaths = set()
+  for dep in deps:
+    manifests += dep.webfiles.manifests
+    webpaths += dep.webfiles.webpaths
+    files += dep.data_runfiles.files
+  ctx.action(
+      inputs=list(manifests + files),
+      outputs=[ctx.outputs.zip],
+      executable=ctx.executable._Zipper,
+      arguments=([ctx.outputs.zip.path] +
+                 [m.path for m in manifests]),
+      progress_message="Zipping %d files" % len(webpaths))
+  transitive_runfiles = set()
+  for dep in deps:
+    transitive_runfiles += dep.data_runfiles.files
+  return struct(
+      files=set([ctx.outputs.zip]),
+      runfiles=ctx.runfiles(
+          files=ctx.files.data + [ctx.outputs.zip],
+          transitive_files=transitive_runfiles))
+
+tensorboard_zip_file = rule(
+    implementation=_tensorboard_zip_file,
+    attrs={
+        "data": attr.label_list(cfg="data", allow_files=True),
+        "deps": attr.label_list(providers=["webfiles"], mandatory=True),
+        "_Zipper": attr.label(
+            default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Zipper"),
+            executable=True,
+            cfg="host"),
+    },
+    outputs={
+        "zip": "%{name}.zip",
+    })
diff --git a/tensorflow/tensorboard/demo/BUILD b/tensorflow/tensorboard/demo/BUILD
new file mode 100644
index 00000000000..b253572ec55
--- /dev/null
+++ b/tensorflow/tensorboard/demo/BUILD
@@ -0,0 +1,20 @@
+package(default_visibility = ["//tensorflow/tensorboard:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "web_library")
+
+licenses(["notice"])  # Apache 2.0
+
+# THIS PACKAGE HAS MOVED
+# See tensorflow/tensorboard/components/tf_tensorboard:demo
+
+web_library(
+    name = "demo_data",
+    srcs = glob(["data/**"]),
+    path = "/",
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(["**"]),
+    tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/demo/data/logdir b/tensorflow/tensorboard/demo/data/logdir
new file mode 100644
index 00000000000..b6362b45d77
--- /dev/null
+++ b/tensorflow/tensorboard/demo/data/logdir
@@ -0,0 +1 @@
+{"logdir": "/foo/some/fake/logdir"}
\ No newline at end of file
diff --git a/tensorflow/tensorboard/demo/index.html b/tensorflow/tensorboard/demo/index.html
deleted file mode 100644
index 581f8a27235..00000000000
--- a/tensorflow/tensorboard/demo/index.html
+++ /dev/null
@@ -1,31 +0,0 @@
[31 lines of deleted HTML omitted]
diff --git a/tensorflow/tensorboard/dist/bazel-html-imports.html b/tensorflow/tensorboard/dist/bazel-html-imports.html
deleted file mode 100644
index 2268e6d7d4c..00000000000
--- a/tensorflow/tensorboard/dist/bazel-html-imports.html
+++ /dev/null
@@ -1,23 +0,0 @@
[23 lines of deleted HTML import markup omitted]
diff --git a/tensorflow/tensorboard/dist/index.html b/tensorflow/tensorboard/dist/index.html
deleted file mode 100644
index 66fce9fe9af..00000000000
--- a/tensorflow/tensorboard/dist/index.html
+++ /dev/null
@@ -1,32 +0,0 @@
[32 lines of deleted HTML omitted; the page title was "TensorBoard"]
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
deleted file mode 100644
index 3e077d1a73a..00000000000
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ /dev/null
@@ -1,24940 +0,0 @@
[24,940 lines of deleted vulcanized HTML omitted]
\ No newline at end of file
diff --git a/tensorflow/tensorboard/gulp_tasks/bower.js b/tensorflow/tensorboard/gulp_tasks/bower.js
deleted file mode 100644
index 7c0e515c6c9..00000000000
--- a/tensorflow/tensorboard/gulp_tasks/bower.js
+++ /dev/null
@@ -1,23 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-var gulp = require('gulp');
-var bower = require('gulp-bower');
-
-module.exports = function() {
-  return function() {
-    return bower();
-  }
-}
diff --git a/tensorflow/tensorboard/gulp_tasks/compile.js b/tensorflow/tensorboard/gulp_tasks/compile.js
deleted file mode 100644
index 3d0d725cfb2..00000000000
--- a/tensorflow/tensorboard/gulp_tasks/compile.js
+++ /dev/null
@@ -1,95 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-var gulp = require('gulp');
-var ts = require('gulp-typescript');
-var typescript = require('typescript');
-var gutil = require('gulp-util');
-var filter = require('gulp-filter');
-var merge = require('merge2');
-var browserify = require('browserify');
-var tsify = require('tsify');
-var source = require('vinyl-source-stream');
-var glob = require('glob').sync;
-var concat = require('gulp-concat');
-
-var tsProject = ts.createProject('./tsconfig.json', {
-  typescript: typescript,
-  noExternalResolve: true, // opt-in for faster compilation!
-});
-
-/** List of components (and their external deps) that are using es6 modules. */
-var ES6_COMPONENTS = [{
-  name: 'vz_projector',
-  deps: [
-    'd3/d3.min.js', 'weblas/dist/weblas.js', 'three.js/build/three.min.js',
-    'three.js/examples/js/controls/OrbitControls.js',
-    'numericjs/lib/numeric-1.2.6.js'
-  ]
-}];
-
-module.exports = function(includeDeps) {
-  return function() {
-    // Compile all components that are using ES6 modules into a bundle.js
-    // using browserify.
-    var entries = ['typings/index.d.ts'];
-    var deps = {};
-    ES6_COMPONENTS.forEach(function(component) {
-      // Collect all the typescript files across the components.
-      entries = entries.concat(glob(
-          'components/' + component.name + '/**/*.ts',
-          // Do not include tests or IDE-purposed files.
-          {ignore: ['**/*_test.ts', '**/deps.d.ts']}));
-      // Collect the unique external deps across all components using es6
-      // modules.
- component.deps.forEach(function(dep) { - deps['components/' + dep] = true; - }); - }); - deps = Object.keys(deps); - - // Compile, bundle all the typescript files and prepend their deps. - browserify(entries) - .plugin(tsify) - .bundle() - .on('error', function(error) { console.error(error.toString()); }) - .pipe(source('bundle.js')) - .pipe(gulp.dest('components')) - .on('end', function() { - // Typescript was compiled and bundled. Now we need to prepend - // the external dependencies. - if (includeDeps) { - gulp.src(deps.concat(['components/bundle.js'])) - .pipe(concat('bundle.js')) - .pipe(gulp.dest('components')); - } - }); - - // Compile components that are using global namespaces producing 1 js file - // for each ts file. - var isComponent = filter([ - 'components/tf_*/**/*.ts', 'components/vz_*/**/*.ts', 'typings/**/*.ts', - 'components/plottable/plottable.d.ts' - // Ignore components that use es6 modules. - ].concat(ES6_COMPONENTS.map(function(component) { - return '!components/' + component.name + '/**/*.ts'; - }))); - - return tsProject.src() - .pipe(isComponent) - .pipe(ts(tsProject)) - .js.pipe(gulp.dest('.')); - }; -}; diff --git a/tensorflow/tensorboard/gulp_tasks/test.js b/tensorflow/tensorboard/gulp_tasks/test.js deleted file mode 100644 index ffa8122c7b5..00000000000 --- a/tensorflow/tensorboard/gulp_tasks/test.js +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -var gulp = require('gulp'); -var tester = require('web-component-tester').test; - -module.exports = function(done) { - tester({}, function(error) { - if (error) { - // Pretty error for gulp. - error = new Error(error.message || error); - error.showStack = false; - } - done(error); - }); -} diff --git a/tensorflow/tensorboard/gulp_tasks/util.js b/tensorflow/tensorboard/gulp_tasks/util.js deleted file mode 100644 index 7a1d2a58ab6..00000000000 --- a/tensorflow/tensorboard/gulp_tasks/util.js +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -var fs = require('fs'); -var path = require('path'); - -/** - * Returns a list of web components inside the components directory for which - * the name predicate is true. 
- */ -exports.getComponents = function(namePredicate) { - return fs.readdirSync('components') - .filter(function(file) { - return fs.statSync(path.join('components', file)).isDirectory() && - namePredicate(file); - }) - .map(function(dir) { return '/' + dir + '/'; }); -}; - -/** - * Returns a list of tensorboard web components that are inside the components - * directory. - */ -exports.tbComponents = exports.getComponents(function(name) { - var prefix = name.slice(0, 3); - return prefix == 'tf_' || prefix == 'vz_'; -}); diff --git a/tensorflow/tensorboard/gulp_tasks/vulcanize.js b/tensorflow/tensorboard/gulp_tasks/vulcanize.js deleted file mode 100644 index b8cdd80af02..00000000000 --- a/tensorflow/tensorboard/gulp_tasks/vulcanize.js +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -var gulp = require('gulp'); -var path = require('path'); -var util = require('./util'); -var vulcanize = require('gulp-vulcanize'); -var replace = require('gulp-replace'); -var rename = require('gulp-rename'); -var header = require('gulp-header'); - -var HEADER_STR = '\n\n' - -var base = path.join(__dirname, '../components'); -// List of redirects of the form path1|path2 for every tensorboard component -// in order to replace dashes with underscores. -// E.g. .../tf-tensorboard|.../tf_tensorboard -var redirects = util.tbComponents.map(function(dir) { - return path.join(base, dir.replace(/_/g, '-')) + '|' + path.join(base, dir); -}); - -var nonTBComponents = util.getComponents(function(name) { - var prefix = name.slice(0, 3); - return prefix !== 'tf_' && prefix !== 'vz_'; -}); - -module.exports = function(overwrite) { - return function() { - var suffix = overwrite ? '' : '.OPENSOURCE'; - // Vulcanize TensorBoard without external libraries. - gulp.src('components/tf_tensorboard/tf-tensorboard.html') - .pipe(vulcanize({ - inlineScripts: true, - inlineCss: true, - stripComments: true, - excludes: nonTBComponents, - redirects: redirects - })) - .pipe(header(HEADER_STR)) - .pipe(rename('tf-tensorboard.html' + suffix)) - .pipe(gulp.dest('./dist')); - } -} diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js deleted file mode 100644 index 257ee0ab83d..00000000000 --- a/tensorflow/tensorboard/gulpfile.js +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -var gulp = require('gulp'); -var server = require('gulp-server-livereload'); -var minimist = require('minimist'); -var util = require('./gulp_tasks/util'); - -var options = minimist(process.argv.slice(2), { - default: { - p: 8000, // port for gulp server - h: '0.0.0.0', // host to serve on - } -}); - -function getTask(task) { - return require('./gulp_tasks/' + task); -} - - -gulp.task('compile', getTask('compile')(true)); -gulp.task('first-compile', getTask('compile')(true)); -gulp.task('compile-without-deps', getTask('compile')(false)); -gulp.task('test.onlytest', getTask('test')); -gulp.task('test', ['compile'], getTask('test')); - -gulp.task('watch', [], function() { - // Avoid watching generated .d.ts in the build (aka output) directory. - return gulp.watch( - ['components/tf_*/**/*.ts', 'components/vz_*/**/*.ts'], - {ignoreInitial: true}, ['compile']); -}); - -var httpPrefix = 'http://' + options.h + ':' + options.p + '/components'; -var proxies = util.tbComponents.map(function(component) { - return { - source: '/components' + component.replace(/_/g, '-'), - target: httpPrefix + component - }; -}); - -// Do first-compile before turning on server, to avoid spamming -// livereload info -// TODO(danmane): Disconnect this once we can get livereload to -// no longer spam. -gulp.task('server', ['first-compile'], function() { - gulp.src('.').pipe(server({ - host: options.h, - port: options.p, - livereload: { - enable: true, - // Don't livereload on .ts changes, since they aren't loaded by browser. - filter: function(filePath, cb) { cb(!(/\.ts$/.test(filePath))); }, - port: 27729 + options.p - }, - proxies: proxies, - directoryListing: true, - })); -}); - -// TODO(danmane): When testing is nicer, integrate into vulcanize task -// gulp vulcanize: Regenerate the tf-tensorboard.html.OPENSOURCE file for pre-release -gulp.task( - 'vulcanize', ['compile-without-deps'], - getTask('vulcanize')(false)); -// gulp regenerate: Regenerate the tf-tensorboard.html for interactive bazel development -gulp.task( - 'regenerate', ['compile-without-deps'], - getTask('vulcanize')(true)); - -// TODO(danmane): consider making bower install part of default task -gulp.task('default', ['watch', 'server']); - -// Clean all compiled JS files. -var cleanCompiledTypeScript = require('gulp-clean-compiled-typescript'); -gulp.task('clean', function () { - return gulp.src(['./components/**/*.ts', '!./components/**/deps.d.ts']) - .pipe(cleanCompiledTypeScript()); -}); diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md index 16c2f95ae1c..c2885daf93c 100644 --- a/tensorflow/tensorboard/http_api.md +++ b/tensorflow/tensorboard/http_api.md @@ -36,42 +36,43 @@ Returns a JSON object with a key "logdir" that maps to the `logdir` argument The `logdir` argument is the path of the directory that contains events files. +## `data/plugins_listing` + +Returns a dict mapping from plugin name to a boolean indicating whether the +plugin is active. A plugin might be inactive, for instance, if it lacks relevant +data. Every plugin has a key. This route helps the frontend avoid issuing +requests to an inactive plugin - the routes of an inactive plugin do not work. + ## `data/runs` -Returns a dictionary mapping from `run name` (quoted string) to dictionaries -mapping from all available tagTypes to a list of tags of that type available for -the run. 
Think of this as a comprehensive index of all of the data available -from the TensorBoard server. Here is an example: +Returns an array containing the names of all the runs known to the +TensorBoard backend at this time. Each entry is a string corresponding +to a single run. + +We guarantee that as new runs are created in the log directory, they +will always appear at the end of the list returned by this route. That +is, the order of runs is persistent, and the result of this route is an +“append-only” list. + +Example response: + + ["train_run", "eval"] + +## `/data/plugin/scalars/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all +scalar tags present in the corresponding run. Here is an example: { - "train_run": { - "histograms": ["foo_histogram", "bar_histogram"], - "compressedHistograms": ["foo_histogram", "bar_histogram"], - "scalars": ["xent", "loss", "learning_rate"], - "images": ["input"], - "audio": ["input_audio"], - "graph": true, - "firstEventTimestamp": 123456.789 - "run_metadata": ["forward prop", "inference"] - }, - "eval": { - "histograms": ["foo_histogram", "bar_histogram"], - "compressedHistograms": ["foo_histogram", "bar_histogram"], - "scalars": ["precision", "recall"], - "images": ["input"], - "audio": ["input_audio"], - "graph": false, - "run_metadata": [] - } - } + "train_run": ["xent", "loss", "learning_rate"], + "eval": ["precision", "recall"] + } -The `firstEventTimestamp` value is in seconds since the epoch. +Note that runs without any scalar tags are included as keys with value the +empty array. -Note that the same tag may be present for many runs. It is not guaranteed that -they will have the same meaning across runs. It is also not guaranteed that they -will have the same tag type across different runs. - -## '/data/scalars?run=foo&tag=bar' +## `/data/plugin/scalars/scalars?run=foo&tag=bar` Returns an array of event_accumulator.SimpleValueEvents ([wall_time, step, value]) for the given run and tag. wall_time is seconds since epoch. @@ -93,28 +94,21 @@ format: 1443857105.704628,3438,0.5427092909812927 1443857225.705133,5417,0.5457325577735901 -## '/data/scalars?[sample_count=10]' +## `/data/plugin/histograms/tags` -Without any parameters, returns a dictionary mapping from run name to a -dictionary mapping from tag name to a sampled list of scalars from that run and -tag. The values are given in the same format as when the run and tag are -specified. For example: +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all +histogram tags present in the corresponding run. Here is an example: { - "train_run": { - "my_tag": [ - [1443856985.705543, 1448, 0.7461960315704346], - [1443857105.704628, 3438, 0.5427092909812927], - [1443857225.705133, 5417, 0.5457325577735901] - ] - } + "train_run": ["foo_histogram", "bar_histogram"], + "eval": ["foo_histogram", "bar_histogram"] } -The samples are distributed uniformly over the list of values. The sample_count -parameter is optional and defaults to 10; it must be at least 2. The first and -the last value will always be sampled. +Note that runs without any histogram tags are included as keys with +value the empty array. -## '/data/histograms?run=foo&tag=bar' +## `/data/plugin/histograms/histograms?run=foo&tag=bar` Returns an array of event_accumulator.HistogramEvents ([wall_time, step, HistogramValue]) for the given run and tag. 
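For example, a client could discover histogram tags and then fetch one series through these routes, as sketched below; the host, port, run name, and tag are assumptions for illustration, not part of the API contract.

```python
# Illustrative client for the plugin routes above; host, run and tag
# are assumptions, not part of the API contract.
import json
from urllib.request import urlopen

BASE = 'http://localhost:6006/data/plugin/histograms'

tags = json.load(urlopen(BASE + '/tags'))  # e.g. {"train_run": [...], ...}
events = json.load(urlopen(BASE + '/histograms?run=train_run&tag=loss'))
for wall_time, step, histogram_value in events:
  print(step, histogram_value)
```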
A HistogramValue is [min, max, num, @@ -141,7 +135,21 @@ Annotated Example: (note - real data is higher precision) ] ] -## '/data/compressedHistograms?run=foo&tag=bar' +## `/data/plugin/distributions/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all +distribution tags present in the corresponding run. Here is an example: + + { + "train_run": ["foo_histogram", "bar_histogram"], + "eval": ["foo_histogram", "bar_histogram"] + } + +Note that runs without any distribution tags are included as keys with +value the empty array. + +## `/data/plugin/distributions/distributions?run=foo&tag=bar` Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, step, CompressedHistogramValues]) for the given run and tag. @@ -161,8 +169,8 @@ Annotated Example: (note - real data is higher precision) [ 1441154832.580509, # wall_time 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile + [ [0, -3.67], # CompressedHistogramValue for 0th percentile + [2500, -4.19], # CompressedHistogramValue for 25th percentile [5000, 6.29], [7500, 1.64], [10000, 3.67] @@ -171,13 +179,13 @@ Annotated Example: (note - real data is higher precision) ... ] -## `/data/images?run=foo&tag=bar` +## `/data/plugin/images/images?run=foo&tag=bar` Gets a sample of ImageMetadatas for the given run and tag. Returns an array of objects containing information about available images, crucially including the query parameter that may be used to retrieve that image. -(See /individualImage for details.) +(See /data/plugin/images/individualImage for details.) For example: @@ -190,7 +198,7 @@ For example: # param for /individualImage } -## `/data/individualImage?{{query}}` +## `/data/plugin/images/individualImage?{{query}}` Retrieves an individual image. The image query should not be generated by the frontend, but instead acquired from calling the /images route (the image @@ -202,15 +210,29 @@ within a single run, as images may be removed from the sampling reservoir and replaced with other images. (See Notes for details on the reservoir sampling.) An example call to this route would look like this: -/individualImage?index=0&tagname=input%2Fimage%2F2&run=train +/data/plugin/images/individualImage?index=0&tagname=input%2Fimage%2F2&run=train -## `/audio?run=foo&tag=bar` +## `/data/plugin/images/tags` + +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all image +tags present in the corresponding run. Here is an example: + + { + "train": ["foo_image", "bar_image"], + "eval": ["foo_image", "bar_image"] + } + +Note that runs without any image tags are included as keys with value the empty +array. + +## `/data/plugin/audio/audio?run=foo&tag=bar` Gets a sample of AudioMetadatas for the given run and tag. Returns an array of objects containing information about available audio, crucially including the query parameter that may be used to retrieve that audio. -(See /individualAudio for details.) +(See /data/plugin/audio/individualAudio for details.) For example: @@ -222,7 +244,7 @@ For example: # param for /individualAudio } -## `/individualAudio?{{query}}` +## `/data/plugin/audio/individualAudio?{{query}}` Retrieves an individual audio clip. 
The audio query should not be generated by the frontend, but instead acquired from calling the /audio route (the audio @@ -236,11 +258,33 @@ replaced with other clips. (See Notes for details on the reservoir sampling.) An example call to this route would look like this: /individualAudio?index=0&tagname=input%2Faudio%2F2&run=train -## `/data/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` +## `/data/plugin/audio/tags` -Returns the graph definition for the given run in gzipped pbtxt format. The -graph is composed of a list of nodes, where each node is a specific TensorFlow -operation which takes as inputs other nodes (operations). +Returns a dictionary mapping from `run_name` (quoted string) to arrays of +`tag_name` (quoted string), where each array contains the names of all audio +tags present in the corresponding run. Here is an example: + + { + "train": ["foo_audio", "bar_audio"], + "eval": ["foo_audio", "bar_audio"], + } + +Note that runs without any audio tags are included as keys with value the empty +array. + +## `/data/plugin/graphs/runs` + +Returns a list of runs that have associated graphs. + +For example: + + ["train"] + +## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` + +Returns the graph definition for the given run in pbtxt format. The +graph is composed of a list of nodes, where each node is a specific +TensorFlow operation which takes as inputs other nodes (operations). The query parameters `limit_attr_size` and `large_attrs_key` are optional. @@ -253,7 +297,10 @@ attributes that are too large. The value of this key (list of strings) should be used by the client in order to determine which attributes have been filtered. Must be specified if `limit_attr_size` is specified. -For the query `/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large`, +For the query + + /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, + here is an example pbtxt response of a graph with 3 nodes, where the second node had two large attributes "a" and "b" that were filtered out (size > 1024): diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD new file mode 100644 index 00000000000..f1f7746ff84 --- /dev/null +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD @@ -0,0 +1,56 @@ +package(default_visibility = ["//tensorflow/tensorboard:internal"]) + +licenses(["notice"]) # Apache 2.0 + +java_binary( + name = "Vulcanize", + srcs = ["Vulcanize.java"], + jvm_flags = [ + "-Xss20m", # JSCompiler needs big stacks for recursive parsing + "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive + "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam + ], + visibility = ["//visibility:public"], + deps = [ + "@com_google_guava", + "@com_google_protobuf_java", + "@io_bazel_rules_closure//closure/compiler", + "@io_bazel_rules_closure//java/io/bazel/rules/closure:webpath", + "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", + "@io_bazel_rules_closure//java/org/jsoup/nodes", + "@org_jsoup", + ], +) + +java_binary( + name = "Zipper", + srcs = ["Zipper.java"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_guava", + "@com_google_protobuf_java", + "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles", + 
"@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", + ], +) + +# These JS files are always taken into consideration by the Closure Compiler +# when vulcanizing, per vulcanize.bzl. +filegroup( + name = "jslibs", + srcs = [ + # Ordering probably matters + "@com_google_javascript_closure_compiler_externs", + "@com_google_javascript_closure_compiler_externs_polymer", + "externs.js", + "@com_google_javascript_closure_library//:closure/goog/base.js", + "@com_google_javascript_closure_library//:closure/goog/deps.js", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java new file mode 100644 index 00000000000..533907dd64d --- /dev/null +++ b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java @@ -0,0 +1,546 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.tensorflow.tensorboard.vulcanize; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Splitter; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; +import com.google.javascript.jscomp.CheckLevel; +import com.google.javascript.jscomp.CompilationLevel; +import com.google.javascript.jscomp.Compiler; +import com.google.javascript.jscomp.CompilerOptions; +import com.google.javascript.jscomp.DiagnosticGroup; +import com.google.javascript.jscomp.DiagnosticGroups; +import com.google.javascript.jscomp.DiagnosticType; +import com.google.javascript.jscomp.JSError; +import com.google.javascript.jscomp.ModuleIdentifier; +import com.google.javascript.jscomp.PropertyRenamingPolicy; +import com.google.javascript.jscomp.Result; +import com.google.javascript.jscomp.SourceFile; +import com.google.javascript.jscomp.WarningsGuard; +import com.google.protobuf.TextFormat; +import io.bazel.rules.closure.Webpath; +import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; +import io.bazel.rules.closure.webfiles.BuildInfo.WebfilesSource; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.ArrayDeque; +import java.util.ArrayList; 
+import java.util.Collection; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.jsoup.Jsoup; +import org.jsoup.nodes.Attribute; +import org.jsoup.nodes.Comment; +import org.jsoup.nodes.DataNode; +import org.jsoup.nodes.Document; +import org.jsoup.nodes.Element; +import org.jsoup.nodes.Html5Printer; +import org.jsoup.nodes.Node; +import org.jsoup.nodes.TextNode; +import org.jsoup.parser.Parser; +import org.jsoup.parser.Tag; + +/** Simple one-off solution for TensorBoard vulcanization. */ +public final class Vulcanize { + + private static final Pattern IGNORE_PATHS_PATTERN = + Pattern.compile("/(?:polymer|marked-element)/.*"); + + private static final ImmutableSet EXTRA_JSDOC_TAGS = + ImmutableSet.of("attribute", "hero", "group", "required"); + + private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); + + private static final Parser parser = Parser.htmlParser(); + private static final Map webfiles = new HashMap<>(); + private static final Set alreadyInlined = new HashSet<>(); + private static final Set legalese = new HashSet<>(); + private static final List licenses = new ArrayList<>(); + private static final List stack = new ArrayList<>(); + private static final List externs = new ArrayList<>(); + private static final List sourcesFromJsLibraries = new ArrayList<>(); + private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); + private static final Map sourceTags = new LinkedHashMap<>(); + private static final Multimap suppressions = HashMultimap.create(); + private static CompilationLevel compilationLevel; + private static Webpath outputPath; + private static Node firstCompiledScript; + private static Node licenseComment; + private static int insideDemoSnippet; + private static boolean testOnly; + + public static void main(String[] args) throws IOException { + compilationLevel = CompilationLevel.fromString(args[0]); + testOnly = args[1].equals("true"); + Webpath inputPath = Webpath.get(args[2]); + outputPath = Webpath.get(args[3]); + Path output = Paths.get(args[4]); + for (int i = 5; i < args.length; i++) { + if (args[i].endsWith(".js")) { + String code = new String(Files.readAllBytes(Paths.get(args[i])), UTF_8); + SourceFile sourceFile = SourceFile.fromCode(args[i], code); + if (code.contains("@externs")) { + externs.add(sourceFile); + } else { + sourcesFromJsLibraries.add(sourceFile); + } + continue; + } + if (!args[i].endsWith(".pbtxt")) { + continue; + } + Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); + for (WebfilesSource src : manifest.getSrcList()) { + webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); + } + } + stack.add(inputPath); + Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); + transform(document); + compile(); + if (licenseComment != null) { + licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); + } + Files.write( + output, + Html5Printer.stringify(document).getBytes(UTF_8), + StandardOpenOption.WRITE, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } + + private static void transform(Node root) throws IOException { + Node node = checkNotNull(root); + Node newNode; + while (true) { + newNode = enterNode(node); + if (node.equals(root)) { + root = newNode; + } + node = 
newNode; + if (node.childNodeSize() > 0) { + node = node.childNode(0); + } else { + while (true) { + newNode = leaveNode(node); + if (node.equals(root)) { + root = newNode; + } + node = newNode; + if (node.equals(root)) { + return; + } + Node next = node.nextSibling(); + if (next == null) { + if (node.parentNode() == null) { + return; + } + node = verifyNotNull(node.parentNode(), "unexpected root: %s", node); + } else { + node = next; + break; + } + } + } + } + } + + private static Node enterNode(Node node) throws IOException { + if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet++; + } + if (insideDemoSnippet > 0) { + return node; + } + if (node instanceof Element) { + if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { + if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { + // Inline HTML. + node = visitHtmlImport(node); + } else if (node.nodeName().equals("script") + && !shouldIgnoreUri(node.attr("src")) + && !node.hasAttr("jscomp-ignore")) { + node = visitScript(node); + } else if (node.nodeName().equals("link") + && node.attr("rel").equals("stylesheet") + && !node.attr("href").isEmpty() + && !shouldIgnoreUri(node.attr("href"))) { + node = visitStylesheet(node); + } + } + rootifyAttribute(node, "href"); + rootifyAttribute(node, "src"); + rootifyAttribute(node, "action"); + rootifyAttribute(node, "assetpath"); + } else if (node instanceof Comment) { + String text = ((Comment) node).getData(); + if (text.contains("@license")) { + handleLicense(text); + if (licenseComment == null) { + licenseComment = node; + } else { + node = replaceNode(node, new TextNode("", node.baseUri())); + } + } else { + node = replaceNode(node, new TextNode("", node.baseUri())); + } + } + return node; + } + + private static Node leaveNode(Node node) { + if (node instanceof Document) { + stack.remove(stack.size() - 1); + } else if (node.nodeName().equals("demo-snippet")) { + insideDemoSnippet--; + } + return node; + } + + private static Node visitHtmlImport(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + if (alreadyInlined.add(href)) { + stack.add(href); + Document subdocument = parse(Files.readAllBytes(getWebfile(href))); + for (Attribute attr : node.attributes()) { + subdocument.attr(attr.getKey(), attr.getValue()); + } + return replaceNode(node, subdocument); + } else { + return replaceNode(node, new TextNode("", node.baseUri())); + } + } + + private static Node visitScript(Node node) throws IOException { + Webpath path; + String script; + if (node.attr("src").isEmpty()) { + path = makeSyntheticName(".js"); + script = getInlineScriptFromNode(node); + } else { + path = me().lookup(Webpath.get(node.attr("src"))); + script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); + } + if (node.attr("src").endsWith(".min.js") + || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { + Node newScript = + new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) + .appendChild(new DataNode(script, node.baseUri())) + .removeAttr("src") + .removeAttr("jscomp-nocompile"); + if (firstCompiledScript != null) { + firstCompiledScript.before(newScript); + return replaceNode(node, new TextNode("", node.baseUri())); + } else { + return replaceNode(node, newScript); + } + } else { + if (firstCompiledScript == null) { + firstCompiledScript = node; + } + sourcesFromScriptTags.put(path, script); + sourceTags.put(path, node); + Optional suppress = getAttrTransitive(node, "jscomp-suppress"); + if 
(suppress.isPresent()) { + if (suppress.get().isEmpty()) { + suppressions.put(path, "*"); + } else { + suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); + } + } + return node; + } + } + + private static Node visitStylesheet(Node node) throws IOException { + Webpath href = me().lookup(Webpath.get(node.attr("href"))); + return replaceNode( + node, + new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) + .appendChild( + new DataNode( + new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) + .removeAttr("rel") + .removeAttr("href")); + } + + private static Optional getAttrTransitive(Node node, String attr) { + while (node != null) { + if (node.hasAttr(attr)) { + return Optional.of(node.attr(attr)); + } + node = node.parent(); + } + return Optional.absent(); + } + + private static Node replaceNode(Node oldNode, Node newNode) { + oldNode.replaceWith(newNode); + return newNode; + } + + private static Path getWebfile(Webpath path) { + return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); + } + + private static void compile() { + if (sourcesFromScriptTags.isEmpty()) { + return; + } + + CompilerOptions options = new CompilerOptions(); + compilationLevel.setOptionsForCompilationLevel(options); + + // Nice options. + options.setColorizeErrorOutput(true); + options.setContinueAfterErrors(true); + options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); + options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); + options.setGenerateExports(true); + options.setStrictModeInput(false); + options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); + + // So we can chop JS binary back up into the original script tags. + options.setPrintInputDelimiter(true); + options.setInputDelimiter("//~~WEBPATH~~%name%"); + + // Optimizations that are too advanced for us right now. + options.setPropertyRenaming(PropertyRenamingPolicy.OFF); + options.setCheckGlobalThisLevel(CheckLevel.OFF); + options.setRemoveUnusedPrototypeProperties(false); + options.setRemoveUnusedPrototypePropertiesInExterns(false); + options.setRemoveUnusedClassProperties(false); + + // Dependency management. + options.setClosurePass(true); + options.setManageClosureDependencies(true); + options.getDependencyOptions().setDependencyPruning(true); + options.getDependencyOptions().setDependencySorting(true); + options.getDependencyOptions().setMoocherDropping(false); + options.getDependencyOptions() + .setEntryPoints( + sourceTags + .keySet() + .stream() + .map(Webpath::toString) + .map(ModuleIdentifier::forFile) + .collect(Collectors.toList())); + + // Polymer pass. + options.setPolymerVersion(1); + + // Debug flags. + if (testOnly) { + options.setPrettyPrint(true); + options.setGeneratePseudoNames(true); + options.setExportTestFunctions(true); + } + + // Don't print warnings from " + sanitized = "<script>alert('xss')</script>" + self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) + + dangerous = textwrap.dedent("""\ + hello *you*""") + sanitized = '

<p>hello <em>you</em></p>'
+    self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized)
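The pair of strings above pins down both halves of the conversion under test: markdown rendering plus HTML sanitization. A minimal sketch of such a pipeline, assuming the third-party `markdown` and `bleach` packages; the helper name and tag whitelist are illustrative, not the plugin's actual code:

```python
# Sketch only: the plugin's real whitelist and implementation may differ.
import bleach
import markdown

ALLOWED_TAGS = ['p', 'em', 'strong', 'a', 'ul', 'ol', 'li',
                'table', 'thead', 'tbody', 'tr', 'th', 'td', 'code', 'pre']

def markdown_and_sanitize_sketch(text):
  html = markdown.markdown(text)  # 'hello *you*' -> '<p>hello <em>you</em></p>'
  # bleach escapes tags outside the whitelist, so a <script> payload comes
  # back HTML-escaped rather than executable.
  return bleach.clean(html, tags=ALLOWED_TAGS)
```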
+
+  def testTableGeneration(self):
+    array2d = np.array([['one', 'two'], ['three', 'four']])
+    expected_table = textwrap.dedent("""\
+        <table>
+        <tbody>
+        <tr>
+        <td>one</td>
+        <td>two</td>
+        </tr>
+        <tr>
+        <td>three</td>
+        <td>four</td>
+        </tr>
+        </tbody>
+        </table>""")
+    self.assertEqual(text_plugin.make_table(array2d), expected_table)
+
+    expected_table_with_headers = textwrap.dedent("""\
+        <table>
+        <thead>
+        <tr>
+        <th>c1</th>
+        <th>c2</th>
+        </tr>
+        </thead>
+        <tbody>
+        <tr>
+        <td>one</td>
+        <td>two</td>
+        </tr>
+        <tr>
+        <td>three</td>
+        <td>four</td>
+        </tr>
+        </tbody>
+        </table>""")
+
+    actual_with_headers = text_plugin.make_table(array2d, headers=['c1', 'c2'])
+    self.assertEqual(actual_with_headers, expected_table_with_headers)
+
+    array_1d = np.array(['one', 'two', 'three', 'four', 'five'])
+    expected_1d = textwrap.dedent("""\
+        <table>
+        <tbody>
+        <tr>
+        <td>one</td>
+        </tr>
+        <tr>
+        <td>two</td>
+        </tr>
+        <tr>
+        <td>three</td>
+        </tr>
+        <tr>
+        <td>four</td>
+        </tr>
+        <tr>
+        <td>five</td>
+        </tr>
+        </tbody>
+        </table>""")
+    self.assertEqual(text_plugin.make_table(array_1d), expected_1d)
+
+    expected_1d_with_headers = textwrap.dedent("""\
+        <table>
+        <thead>
+        <tr>
+        <th>X</th>
+        </tr>
+        </thead>
+        <tbody>
+        <tr>
+        <td>one</td>
+        </tr>
+        <tr>
+        <td>two</td>
+        </tr>
+        <tr>
+        <td>three</td>
+        </tr>
+        <tr>
+        <td>four</td>
+        </tr>
+        <tr>
+        <td>five</td>
+        </tr>
+        </tbody>
+        </table>""")
+    actual_1d_with_headers = text_plugin.make_table(array_1d, headers=['X'])
+    self.assertEqual(actual_1d_with_headers, expected_1d_with_headers)

+  def testMakeTableExceptions(self):
+    # Verify that contents is being type-checked and shape-checked.
+    with self.assertRaises(ValueError):
+      text_plugin.make_table([])
+
+    with self.assertRaises(ValueError):
+      text_plugin.make_table('foo')
+
+    with self.assertRaises(ValueError):
+      invalid_shape = np.full((3, 3, 3), 'nope', dtype=np.dtype('S3'))
+      text_plugin.make_table(invalid_shape)
+
+    # Test headers exceptions in 2d array case.
+    test_array = np.full((3, 3), 'foo', dtype=np.dtype('S3'))
+    with self.assertRaises(ValueError):
+      # Headers is wrong type.
+      text_plugin.make_table(test_array, headers='foo')
+    with self.assertRaises(ValueError):
+      # Too many headers.
+      text_plugin.make_table(test_array, headers=['foo', 'bar', 'zod', 'zoink'])
+    with self.assertRaises(ValueError):
+      # Headers is 2d.
+      text_plugin.make_table(test_array, headers=test_array)
+
+    # Also make sure the column counting logic works in the 1d array case.
+    test_array = np.array(['foo', 'bar', 'zod'])
+    with self.assertRaises(ValueError):
+      # Too many headers.
+      text_plugin.make_table(test_array, headers=test_array)
+
+  def test_reduce_to_2d(self):
+
+    def make_range_array(dim):
+      """Produce an incrementally increasing multidimensional array.
+
+      Args:
+        dim: the number of dimensions for the array
+
+      Returns:
+        An array of increasing integer elements, with dim dimensions and size
+        two in each dimension.
+
+      Example: make_range_array(2) results in [[0,1],[2,3]].
+      """
+      return np.array(range(2**dim)).reshape([2] * dim)
+
+    for i in range(2, 5):
+      actual = text_plugin.reduce_to_2d(make_range_array(i))
+      expected = make_range_array(2)
+      np.testing.assert_array_equal(actual, expected)
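Concretely, for `dim=3` the loop above asserts that the reduction keeps the last two dimensions and collapses every leading axis at index 0; a small worked example of that reading:

```python
import numpy as np

arr = np.array(range(2**3)).reshape([2, 2, 2])  # make_range_array(3)
expected = np.array([[0, 1], [2, 3]])           # make_range_array(2)
# The last two dimensions survive; leading axes are taken at index 0.
assert (arr[0] == expected).all()
```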

+  def test_text_array_to_html(self):
+
+    convert = text_plugin.text_array_to_html
+    scalar = np.array('foo')
+    scalar_expected = '<p>foo</p>'
+    self.assertEqual(convert(scalar), scalar_expected)
+
+    vector = np.array(['foo', 'bar'])
+    vector_expected = textwrap.dedent("""\
+        <table>
+        <tbody>
+        <tr>
+        <td><p>foo</p></td>
+        </tr>
+        <tr>
+        <td><p>bar</p></td>
+        </tr>
+        </tbody>
+        </table>""")
+    self.assertEqual(convert(vector), vector_expected)
+
+    d2 = np.array([['foo', 'bar'], ['zoink', 'zod']])
+    d2_expected = textwrap.dedent("""\
+        <table>
+        <tbody>
+        <tr>
+        <td><p>foo</p></td>
+        <td><p>bar</p></td>
+        </tr>
+        <tr>
+        <td><p>zoink</p></td>
+        <td><p>zod</p></td>
+        </tr>
+        </tbody>
+        </table>""")
+    self.assertEqual(convert(d2), d2_expected)
+
+    d3 = np.array([[['foo', 'bar'], ['zoink', 'zod']], [['FOO', 'BAR'],
+                                                        ['ZOINK', 'ZOD']]])
+
+    warning = text_plugin.markdown_and_sanitize(text_plugin.WARNING_TEMPLATE %
+                                                3)
+    d3_expected = warning + textwrap.dedent("""\
+        <table>
+        <tbody>
+        <tr>
+        <td><p>foo</p></td>
+        <td><p>bar</p></td>
+        </tr>
+        <tr>
+        <td><p>zoink</p></td>
+        <td><p>zod</p></td>
+        </tr>
+        </tbody>
+        </table>""")
+    self.assertEqual(convert(d3), d3_expected)
+
+  def testPluginIsActive(self):
+    plugin = text_plugin.TextPlugin()
+    multiplexer = event_multiplexer.EventMultiplexer()
+    plugin.get_plugin_apps(multiplexer, None)
+
+    # The plugin is inactive because text summaries are not available.
+    self.assertFalse(plugin.is_active())
+
+    multiplexer.AddRunsFromDirectory(self.logdir)
+    multiplexer.Reload()
+
+    # The plugin is active because text summaries are available.
+    self.assertTrue(plugin.is_active())
+
+  def testUnicode(self):
+    self.assertConverted(u'<p>Iñtërnâtiônàlizætiøn⚡💩</p>
', + 'Iñtërnâtiônàlizætiøn⚡💩') + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/tensorboard/scripts/BUILD b/tensorflow/tensorboard/scripts/BUILD index 710191b238a..05425ee61d0 100644 --- a/tensorflow/tensorboard/scripts/BUILD +++ b/tensorflow/tensorboard/scripts/BUILD @@ -1,7 +1,7 @@ # Description: # Some useful scripts that are bundled with TensorBoard. -package(default_visibility = ["//tensorflow:internal"]) +package(default_visibility = ["//tensorflow/tensorboard:internal"]) licenses(["notice"]) # Apache 2.0 @@ -12,18 +12,19 @@ py_binary( srcs = ["generate_testdata.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python", # TODO(b/34059704): remove when fixed - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:logging_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py", "//third_party/py/numpy", "@six_archive//:six", ], ) +py_binary( + name = "execrooter", + srcs = ["execrooter.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + filegroup( name = "all_files", srcs = glob(["*"]), diff --git a/tensorflow/tensorboard/scripts/execrooter.py b/tensorflow/tensorboard/scripts/execrooter.py new file mode 100644 index 00000000000..65569b91512 --- /dev/null +++ b/tensorflow/tensorboard/scripts/execrooter.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility for running programs in a symlinked execroot.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import shutil +import subprocess +import sys +import tempfile + + +def run(inputs, program, outputs): + """Creates temp symlink tree, runs program, and copies back outputs. + + Args: + inputs: List of fake paths to real paths, which are used for symlink tree. + program: List containing real path of program and its arguments. The + execroot directory will be appended as the last argument. + outputs: List of fake outputted paths to copy back to real paths. + Returns: + 0 if succeeded or nonzero if failed. + """ + root = tempfile.mkdtemp() + try: + cwd = os.getcwd() + for fake, real in inputs: + parent = os.path.join(root, os.path.dirname(fake)) + if not os.path.exists(parent): + os.makedirs(parent) + os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) + if subprocess.call(program + [root]) != 0: + return 1 + for fake, real in outputs: + shutil.copyfile(os.path.join(root, fake), real) + return 0 + finally: + shutil.rmtree(root) + + +def main(args): + """Invokes run function using a JSON file config. + + Args: + args: CLI args, which can be a JSON file containing an object whose + attributes are the parameters to the run function. If multiple JSON + files are passed, their contents are concatenated. + Returns: + 0 if succeeded or nonzero if failed. 
+ Raises: + Exception: If input data is missing. + """ + if not args: + raise Exception('Please specify at least one JSON config path') + inputs = [] + program = [] + outputs = [] + for arg in args: + with open(arg) as fd: + config = json.load(fd) + inputs.extend(config.get('inputs', [])) + program.extend(config.get('program', [])) + outputs.extend(config.get('outputs', [])) + if not program: + raise Exception('Please specify a program') + return run(inputs, program, outputs) + + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py index f89ab690ba3..f191d16a82d 100644 --- a/tensorflow/tensorboard/scripts/generate_testdata.py +++ b/tensorflow/tensorboard/scripts/generate_testdata.py @@ -28,20 +28,13 @@ import shutil import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.framework import summary_pb2 -from tensorflow.core.util import event_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.platform import app -from tensorflow.python.platform import flags -from tensorflow.python.summary.writer import writer as writer_lib +import tensorflow as tf -tf.flags.DEFINE_string("target", None, """The directoy where serialized data + +tf.flags.DEFINE_string("target", None, """The directory where serialized data will be written""") -flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite +tf.flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite TARGET if it already exists.""") FLAGS = tf.flags.FLAGS @@ -76,7 +69,7 @@ def _MakeHistogram(values): bucket_limit = [lc[0] for lc in limit_counts] bucket = [lc[1] for lc in limit_counts] sum_sq = sum(v * v for v in values) - return summary_pb2.HistogramProto( + return tf.HistogramProto( min=min(values), max=max(values), num=len(values), @@ -92,9 +85,9 @@ def WriteScalarSeries(writer, tag, f, n=5): wall_time = _start_time for i in xrange(n): v = f(i) - value = summary_pb2.Summary.Value(tag=tag, simple_value=v) - summary = summary_pb2.Summary(value=[value]) - event = event_pb2.Event(wall_time=wall_time, step=step, summary=summary) + value = tf.Summary.Value(tag=tag, simple_value=v) + summary = tf.Summary(value=[value]) + event = tf.Event(wall_time=wall_time, step=step, summary=summary) writer.add_event(event) step += 1 wall_time += 10 @@ -107,10 +100,8 @@ def WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): for [mean, stddev] in mu_sigma_tuples: data = [random.normalvariate(mean, stddev) for _ in xrange(n)] histo = _MakeHistogram(data) - summary = summary_pb2.Summary( - value=[summary_pb2.Summary.Value( - tag=tag, histo=histo)]) - event = event_pb2.Event(wall_time=wall_time, step=step, summary=summary) + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) + event = tf.Event(wall_time=wall_time, step=step, summary=summary) writer.add_event(event) step += 10 wall_time += 100 @@ -119,9 +110,9 @@ def WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): def WriteImageSeries(writer, tag, n_images=1): """Write a few dummy images to writer.""" step = 0 - session = session_lib.Session() - p = array_ops.placeholder("uint8", (1, 4, 4, 3)) - s = logging_ops.image_summary(tag, p) + session = tf.Session() + p = tf.placeholder("uint8", (1, 4, 4, 3)) + s = 
tf.summary.image(tag, p) for _ in xrange(n_images): im = np.random.random_integers(0, 255, (1, 4, 4, 3)) summ = session.run(s, feed_dict={p: im}) @@ -133,18 +124,18 @@ def WriteImageSeries(writer, tag, n_images=1): def WriteAudioSeries(writer, tag, n_audio=1): """Write a few dummy audio clips to writer.""" step = 0 - session = session_lib.Session() + session = tf.Session() min_frequency_hz = 440 max_frequency_hz = 880 sample_rate = 4000 - duration_frames = sample_rate * 0.5 # 0.5 seconds. + duration_frames = sample_rate // 2 # 0.5 seconds. frequencies_per_run = 1 num_channels = 2 - p = array_ops.placeholder("float32", (frequencies_per_run, duration_frames, - num_channels)) - s = logging_ops.audio_summary(tag, p, sample_rate) + p = tf.placeholder("float32", (frequencies_per_run, duration_frames, + num_channels)) + s = tf.summary.audio(tag, p, sample_rate) for _ in xrange(n_audio): # Generate a different frequency for each channel to show stereo works. @@ -170,7 +161,7 @@ def GenerateTestData(path): """Generates the test data directory.""" run1_path = os.path.join(path, "run1") os.makedirs(run1_path) - writer1 = writer_lib.FileWriter(run1_path) + writer1 = tf.summary.FileWriter(run1_path) WriteScalarSeries(writer1, "foo/square", lambda x: x * x) WriteScalarSeries(writer1, "bar/square", lambda x: x * x) WriteScalarSeries(writer1, "foo/sin", math.sin) @@ -183,7 +174,7 @@ def GenerateTestData(path): run2_path = os.path.join(path, "run2") os.makedirs(run2_path) - writer2 = writer_lib.FileWriter(run2_path) + writer2 = tf.summary.FileWriter(run2_path) WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) @@ -194,7 +185,7 @@ def GenerateTestData(path): WriteImageSeries(writer2, "im1") WriteAudioSeries(writer2, "au2") - graph_def = graph_pb2.GraphDef() + graph_def = tf.GraphDef() node1 = graph_def.node.add() node1.name = "a" node1.op = "matmul" @@ -231,4 +222,4 @@ def main(unused_argv=None): if __name__ == "__main__": - app.run() + tf.app.run() diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py deleted file mode 100644 index 87f4f9fe511..00000000000 --- a/tensorflow/tensorboard/tensorboard.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Serve TensorFlow summary data to a web frontend. - -This is a simple web server to proxy data from the event_loader to the web, and -serve static web files. 
-""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import socket -from werkzeug import serving - -from tensorflow.python.platform import app -from tensorflow.python.platform import flags -from tensorflow.python.platform import resource_loader -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.summary import event_file_inspector as efi -from tensorflow.python.summary import event_multiplexer -from tensorflow.tensorboard.backend import application -from tensorflow.tensorboard.plugins.debugger import plugin as debugger_plugin -from tensorflow.tensorboard.plugins.projector import plugin as projector_plugin - -flags.DEFINE_string('logdir', '', """logdir specifies the directory where -TensorBoard will look to find TensorFlow event files that it can display. -TensorBoard will recursively walk the directory structure rooted at logdir, -looking for .*tfevents.* files. - -You may also pass a comma separated list of log directories, and TensorBoard -will watch each directory. You can also assign names to individual log -directories by putting a colon between the name and the path, as in - -tensorboard --logdir=name1:/path/to/logs/1,name2:/path/to/logs/2 -""") - -flags.DEFINE_boolean( - 'insecure_debug_mode', False, 'Whether to run the app in debug mode. ' - 'This increases log verbosity, and enables debugging on server exceptions.') - -flags.DEFINE_string('host', '0.0.0.0', 'What host to listen to. Defaults to ' - 'serving on 0.0.0.0, set to 127.0.0.1 (localhost) to' - 'disable remote access (also quiets security warnings).') - -flags.DEFINE_boolean('inspect', False, """Use this flag to print out a digest -of your event files to the command line, when no data is shown on TensorBoard or -the data shown looks weird. - -Example usages: -tensorboard --inspect --event_file=myevents.out -tensorboard --inspect --event_file=myevents.out --tag=loss -tensorboard --inspect --logdir=mylogdir -tensorboard --inspect --logdir=mylogdir --tag=loss - -See tensorflow/python/summary/event_file_inspector.py for more info and -detailed usage. -""") -flags.DEFINE_string( - 'tag', '', - 'The particular tag to query for. Only used if --inspect is present') -flags.DEFINE_string( - 'event_file', '', - 'The particular event file to query for. Only used if --inspect is present ' - 'and --logdir is not specified.') - -flags.DEFINE_integer('port', 6006, 'What port to serve TensorBoard on.') - -flags.DEFINE_boolean('purge_orphaned_data', True, 'Whether to purge data that ' - 'may have been orphaned due to TensorBoard restarts. ' - 'Disabling purge_orphaned_data can be used to debug data ' - 'disappearance.') - -flags.DEFINE_integer('reload_interval', 60, 'How often the backend should load ' - 'more data.') - -FLAGS = flags.FLAGS - - -class Server(object): - """A simple WSGI-compliant http server that can serve TensorBoard.""" - - def get_tag(self): - """Read the TensorBoard TAG number, and return it or an empty string.""" - try: - tag = resource_loader.load_resource('tensorboard/TAG').strip() - logging.info('TensorBoard is tag: %s', tag) - return tag - except IOError: - logging.info('Unable to read TensorBoard tag') - return '' - - def create_app(self): - """Creates a WSGI-compliant app than can handle TensorBoard requests. - - Returns: - (function) A complete WSGI application that handles TensorBoard requests. - """ - - logdir = os.path.expanduser(FLAGS.logdir) - if not logdir: - msg = ('A logdir must be specified. 
Run `tensorboard --help` for ' - 'details and examples.') - logging.error(msg) - print(msg) - return -1 - - multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=application.DEFAULT_SIZE_GUIDANCE, - purge_orphaned_data=FLAGS.purge_orphaned_data) - plugins = { - debugger_plugin.PLUGIN_PREFIX_ROUTE: - debugger_plugin.DebuggerPlugin(), - projector_plugin.PLUGIN_PREFIX_ROUTE: - projector_plugin.ProjectorPlugin(), - } - return application.TensorBoardWSGIApp( - logdir, - plugins, - multiplexer, - reload_interval=FLAGS.reload_interval) - - def serve(self): - """Starts a WSGI server that serves the TensorBoard app.""" - - tb_app = self.create_app() - logging.info('Starting TensorBoard in directory %s', os.getcwd()) - debug = FLAGS.insecure_debug_mode - if debug: - logging.set_verbosity(logging.DEBUG) - logging.warning('TensorBoard is in debug mode. This is NOT SECURE.') - - print('Starting TensorBoard %s on port %d' % (self.get_tag(), FLAGS.port)) - if FLAGS.host == '0.0.0.0': - try: - host = socket.gethostbyname(socket.gethostname()) - print('(You can navigate to http://%s:%d)' % (host, FLAGS.port)) - except socket.gaierror: - pass - else: - print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port)) - - try: - serving.run_simple( - FLAGS.host, - FLAGS.port, - tb_app, - threaded=True, - use_reloader=debug, - use_evalex=debug, - use_debugger=debug) - except socket.error: - if FLAGS.port == 0: - msg = 'Unable to find any open ports.' - logging.error(msg) - print(msg) - return -2 - else: - msg = 'Tried to connect to port %d, but address is in use.' % FLAGS.port - logging.error(msg) - print(msg) - return -3 - - -def main(unused_argv=None): - if FLAGS.inspect: - logging.info('Not bringing up TensorBoard, but inspecting event files.') - event_file = os.path.expanduser(FLAGS.event_file) - efi.inspect(FLAGS.logdir, event_file, FLAGS.tag) - return 0 - - Server().serve() - - -if __name__ == '__main__': - app.run() diff --git a/tensorflow/tensorboard/tsconfig.json b/tensorflow/tensorboard/tsconfig.json deleted file mode 100644 index ac69c30533f..00000000000 --- a/tensorflow/tensorboard/tsconfig.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "compilerOptions": { - "noImplicitAny": false, - "noEmitOnError": true, - "target": "ES5", - "module": "commonjs" - }, - "compileOnSave": false, - "exclude": [ - "node_modules", - "typings/main.d.ts", - "typings/main", - "lib", - "components/**/deps.d.ts" - ] -} diff --git a/tensorflow/tensorboard/tslint.json b/tensorflow/tensorboard/tslint.json deleted file mode 100644 index 2a5d995e710..00000000000 --- a/tensorflow/tensorboard/tslint.json +++ /dev/null @@ -1,64 +0,0 @@ -{ - "rules": { - "class-name": true, - "comment-format": [true, "check-space"], - "curly": true, - "eofline": true, - "forin": true, - "jsdoc-format": true, - "label-position": true, - "label-undefined": true, - "max-line-length": [true, 80], - "member-ordering": [false, "variables-before-functions"], - "no-arg": true, - "no-consecutive-blank-lines": true, - "no-console": [true, - "log", - "debug", - "info", - "time", - "timeEnd", - "trace", - "warn" - ], - "no-construct": true, - "no-constructor-vars": true, - "no-debugger": true, - "no-duplicate-key": true, - "no-duplicate-variable": true, - "no-empty": true, - "no-eval": true, - "no-trailing-whitespace": true, - "no-unreachable": true, - "no-unused-expression": true, - "no-unused-variable": false, - "no-use-before-declare": false, - "one-line": [true, - "check-catch", - "check-else", - "check-open-brace", - "check-whitespace" - ], - 
"quotemark": [true, - "single" - ], - "radix": true, - "semicolon": [true, "always"], - "triple-equals": [true, - "allow-null-check" - ], - "typedef-whitespace": [true, { - "call-signature": "nospace", - "index-signature": "nospace", - "parameter": "nospace", - "property-declaration": "nospace", - "variable-declaration": "nospace" - }], - "whitespace": [true, - "check-branch", - "check-decl", - "check-operator", - "check-type" - ] - } -} diff --git a/tensorflow/tensorboard/typings.json b/tensorflow/tensorboard/typings.json deleted file mode 100644 index c36aa2fb9cc..00000000000 --- a/tensorflow/tensorboard/typings.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "name": "tensorflow-vis", - "dependencies": {}, - "globalDependencies": { - "chai": "registry:dt/chai#3.4.0+20160317120654", - "d3": "registry:dt/d3#0.0.0+20160514171929", - "es6-promise": "registry:dt/es6-promise#0.0.0+20160423074304", - "lodash": "registry:dt/lodash#3.10.0+20160330154726", - "mocha": "registry:dt/mocha#2.2.5+20160317120654", - "polymer": "registry:dt/polymer#1.1.6+20160922133320", - "sinon": "registry:dt/sinon#1.16.0+20160517064723", - "three": "registry:dt/three#0.0.0+20160802154944", - "webcomponents.js": "registry:dt/webcomponents.js#0.6.0+20160728153134" - } -} diff --git a/tensorflow/tensorboard/wct.conf.json b/tensorflow/tensorboard/wct.conf.json deleted file mode 100644 index 519218ce418..00000000000 --- a/tensorflow/tensorboard/wct.conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "suites": [ - "components/tf_*/test", - "components/vz_*/test" - ], - "plugins": ["local"] -} diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index a82bcfee611..d01342827dc 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1,136 +1,178 @@ # -*- Python -*- + # Given a source file, generate a test name. # i.e. "common_runtime/direct_session_test.cc" becomes # "common_runtime_direct_session_test" def src_to_test_name(src): return src.replace("/", "_").split(".")[0] + # Return the options to use for a C++ library or binary build. # Uses the ":optmode" config_setting to pick the options. load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", "tf_sycl_tests_tags", - "tf_additional_xla_deps_py", -) -load( - "@local_config_cuda//cuda:build_defs.bzl", - "if_cuda", - "cuda_default_copts" -) + "tf_additional_xla_deps_py",) +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "cuda_default_copts") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", -) + "if_mkl",) + + +def full_path(relative_paths): + return [PACKAGE_NAME + "/" + relative for relative in relative_paths] # List of proto files for android builds def tf_android_core_proto_sources(core_proto_sources_relative): - return ["//tensorflow/core:" + p - for p in core_proto_sources_relative] + return [ + "//tensorflow/core:" + p for p in core_proto_sources_relative + ] + # Returns the list of pb.h and proto.h headers that are generated for # tf_android_core_proto_sources(). 
def tf_android_core_proto_headers(core_proto_sources_relative): - return (["//tensorflow/core/" + p.replace(".proto", ".pb.h") - for p in core_proto_sources_relative] + - ["//tensorflow/core/" + p.replace(".proto", ".proto.h") - for p in core_proto_sources_relative]) + return ([ + "//tensorflow/core/" + p.replace(".proto", ".pb.h") + for p in core_proto_sources_relative + ] + [ + "//tensorflow/core/" + p.replace(".proto", ".proto.h") + for p in core_proto_sources_relative + ]) + + +# Sanitize a dependency so that it works correctly from code that includes +# TensorFlow as a submodule. +def clean_dep(dep): + return str(Label(dep)) + + +def if_android_x86(a): + return select({ + clean_dep("//tensorflow:android_x86"): a, + clean_dep("//tensorflow:android_x86_64"): a, + "//conditions:default": [], + }) + def if_android_arm(a): return select({ - "//tensorflow:android_arm": a, + clean_dep("//tensorflow:android_arm"): a, "//conditions:default": [], }) + def if_android_arm64(a): return select({ - "//tensorflow:android_arm64": a, + clean_dep("//tensorflow:android_arm64"): a, "//conditions:default": [], }) + def if_not_android(a): return select({ - "//tensorflow:android": [], + clean_dep("//tensorflow:android"): [], "//conditions:default": a, }) + def if_android(a): return select({ - "//tensorflow:android": a, + clean_dep("//tensorflow:android"): a, "//conditions:default": [], }) + def if_ios(a): return select({ - "//tensorflow:ios": a, + clean_dep("//tensorflow:ios"): a, "//conditions:default": [], }) + def if_mobile(a): return select({ - "//tensorflow:android": a, - "//tensorflow:ios": a, + clean_dep("//tensorflow:android"): a, + clean_dep("//tensorflow:ios"): a, "//conditions:default": [], }) + def if_not_mobile(a): return select({ - "//tensorflow:android": [], - "//tensorflow:ios": [], + clean_dep("//tensorflow:android"): [], + clean_dep("//tensorflow:ios"): [], "//conditions:default": a, }) + def if_not_windows(a): return select({ - "//tensorflow:windows": [], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": a, }) + def if_x86(a): return select({ - "//tensorflow:linux_x86_64": a, - "//tensorflow:windows": a, + clean_dep("//tensorflow:linux_x86_64"): a, + clean_dep("//tensorflow:windows"): a, + clean_dep("//tensorflow:windows_msvc"): a, "//conditions:default": [], }) +def if_darwin(a): + return select({ + clean_dep("//tensorflow:darwin"): a, + "//conditions:default": [], + }) + +WIN_COPTS = [ + "/DLANG_CXX11", + "/D__VERSION__=\\\"MSVC\\\"", + "/DPLATFORM_WINDOWS", + "/DTF_COMPILE_LIBRARY", + "/DEIGEN_HAS_C99_MATH", + "/DTENSORFLOW_USE_EIGEN_THREADPOOL", +] + # LINT.IfChange def tf_copts(): - return (["-DEIGEN_AVOID_STL_ARRAY", - "-Iexternal/gemmlowp", - "-Wno-sign-compare", - "-fno-exceptions",] + - if_cuda(["-DGOOGLE_CUDA=1"]) + - if_mkl(["-DINTEL_MKL=1"]) + - if_android_arm(["-mfpu=neon"]) + - if_x86(["-msse4.1"]) + - select({ - "//tensorflow:android": [ - "-std=c++11", - "-DTF_LEAN_BINARY", - "-O2", - ], - "//tensorflow:darwin": [], - "//tensorflow:windows": [ - "/DLANG_CXX11", - "/D__VERSION__=\\\"MSVC\\\"", - "/DPLATFORM_WINDOWS", - "/DEIGEN_HAS_C99_MATH", - "/DTENSORFLOW_USE_EIGEN_THREADPOOL", - "/DEIGEN_VECTORIZE_SSE3", # To flush denormals without __SSE3__ set. 
- ], - "//tensorflow:ios": ["-std=c++11"], - "//conditions:default": ["-pthread"]})) + return ([ + "-DEIGEN_AVOID_STL_ARRAY", + "-Iexternal/gemmlowp", + "-Wno-sign-compare", + "-fno-exceptions", + ] + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( + ["-mfpu=neon"]) + if_x86(["-msse3"]) + select({ + clean_dep("//tensorflow:android"): [ + "-std=c++11", + "-DTF_LEAN_BINARY", + "-O2", + ], + clean_dep("//tensorflow:darwin"): [], + clean_dep("//tensorflow:windows"): WIN_COPTS, + clean_dep("//tensorflow:windows_msvc"): WIN_COPTS, + clean_dep("//tensorflow:ios"): ["-std=c++11"], + "//conditions:default": ["-pthread"] + })) + def tf_opts_nortti_if_android(): return if_android([ "-fno-rtti", "-DGOOGLE_PROTOBUF_NO_RTTI", "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER", - ]) + ]) + if_android_x86(["-msse4.1"]) + + # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) + # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. def tf_gen_op_libs(op_lib_names, deps=None): @@ -139,16 +181,20 @@ def tf_gen_op_libs(op_lib_names, deps=None): if not deps: deps = [] for n in op_lib_names: - native.cc_library(name=n + "_op_lib", - copts=tf_copts(), - srcs=["ops/" + n + ".cc"], - deps=deps + ["//tensorflow/core:framework"], - visibility=["//visibility:public"], - alwayslink=1, - linkstatic=1,) + native.cc_library( + name=n + "_op_lib", + copts=tf_copts(), + srcs=["ops/" + n + ".cc"], + deps=deps + [clean_dep("//tensorflow/core:framework")], + visibility=["//visibility:public"], + alwayslink=1, + linkstatic=1,) -def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", - op_gen="//tensorflow/cc:cc_op_gen_main", + +def tf_gen_op_wrapper_cc(name, + out_ops_file, + pkg="", + op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"), deps=None, override_file=None, include_internal_ops=0): @@ -157,12 +203,11 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", if deps == None: deps = [pkg + ":" + name + "_op_lib"] native.cc_binary( - name = tool, - copts = tf_copts(), - linkopts = ["-lm"], - linkstatic = 1, # Faster to link this one-time-use binary dynamically - deps = [op_gen] + deps - ) + name=tool, + copts=tf_copts(), + linkopts=["-lm"], + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps=[op_gen] + deps) if override_file == None: srcs = [] @@ -172,14 +217,17 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", override_arg = "$(location " + override_file + ")" native.genrule( name=name + "_genrule", - outs=[out_ops_file + ".h", out_ops_file + ".cc", - out_ops_file + "_internal.h", out_ops_file + "_internal.cc"], + outs=[ + out_ops_file + ".h", out_ops_file + ".cc", + out_ops_file + "_internal.h", out_ops_file + "_internal.cc" + ], srcs=srcs, tools=[":" + tool], cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " + "$(location :" + out_ops_file + ".cc) " + override_arg + " " + str(include_internal_ops))) + # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate individual C++ .cc and .h # files for each of the ops files mentioned, and then generate a @@ -206,18 +254,18 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", # hdrs = [ "ops/array_ops_internal.h", # "ops/math_ops_internal.h" ], # deps = [ ... ]) -# TODO(josh11b): Cleaner approach for hidden ops. +# TODO(joshl): Cleaner approach for hidden ops. 
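The usage comment above already sketches `tf_gen_op_wrappers_cc` (defined next); `tf_gen_op_libs`, reformatted earlier in this hunk, has no such sketch, so here is a minimal one. The `my_ops` target and its `ops/my_ops.cc` source are invented for illustration:

```python
# Hypothetical BUILD snippet; "my_ops" and ops/my_ops.cc are assumptions.
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")

# Creates a cc_library named "my_ops_op_lib" from ops/my_ops.cc, built with
# alwayslink = 1 and linkstatic = 1 so REGISTER_OP registrations are not
# dropped by the linker; clean_dep("//tensorflow/core:framework") is
# appended to deps automatically.
tf_gen_op_libs(op_lib_names = ["my_ops"])
```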
def tf_gen_op_wrappers_cc(name, op_lib_names=[], other_srcs=[], other_hdrs=[], pkg="", deps=[ - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/cc:const_op", + clean_dep("//tensorflow/cc:ops"), + clean_dep("//tensorflow/cc:scope"), + clean_dep("//tensorflow/cc:const_op"), ], - op_gen="//tensorflow/cc:cc_op_gen_main", + op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"), override_file=None, include_internal_ops=0, visibility=None): @@ -227,59 +275,72 @@ def tf_gen_op_wrappers_cc(name, internalhdrs = [] for n in op_lib_names: tf_gen_op_wrapper_cc( - n, "ops/" + n, pkg=pkg, op_gen=op_gen, override_file=override_file, + n, + "ops/" + n, + pkg=pkg, + op_gen=op_gen, + override_file=override_file, include_internal_ops=include_internal_ops) subsrcs += ["ops/" + n + ".cc"] subhdrs += ["ops/" + n + ".h"] internalsrcs += ["ops/" + n + "_internal.cc"] internalhdrs += ["ops/" + n + "_internal.h"] - native.cc_library(name=name, - srcs=subsrcs, - hdrs=subhdrs, - deps=deps + if_not_android([ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ]) + if_android([ - "//tensorflow/core:android_tensorflow_lib", - ]), - copts=tf_copts(), - alwayslink=1, - visibility=visibility) - native.cc_library(name=name + "_internal", - srcs=internalsrcs, - hdrs=internalhdrs, - deps=deps + if_not_android([ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ]) + if_android([ - "//tensorflow/core:android_tensorflow_lib", - ]), - copts=tf_copts(), - alwayslink=1, - visibility=["//tensorflow:internal"]) + native.cc_library( + name=name, + srcs=subsrcs, + hdrs=subhdrs, + deps=deps + if_not_android([ + clean_dep("//tensorflow/core:core_cpu"), + clean_dep("//tensorflow/core:framework"), + clean_dep("//tensorflow/core:lib"), + clean_dep("//tensorflow/core:protos_all_cc"), + ]) + if_android([ + clean_dep("//tensorflow/core:android_tensorflow_lib"), + ]), + copts=tf_copts(), + alwayslink=1, + visibility=visibility) + native.cc_library( + name=name + "_internal", + srcs=internalsrcs, + hdrs=internalhdrs, + deps=deps + if_not_android([ + clean_dep("//tensorflow/core:core_cpu"), + clean_dep("//tensorflow/core:framework"), + clean_dep("//tensorflow/core:lib"), + clean_dep("//tensorflow/core:protos_all_cc"), + ]) + if_android([ + clean_dep("//tensorflow/core:android_tensorflow_lib"), + ]), + copts=tf_copts(), + alwayslink=1, + visibility=[clean_dep("//tensorflow:internal")]) + # Invoke this rule in .../tensorflow/python to build the wrapper library. -def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[], - require_shape_functions=False, hidden_file=None, +def tf_gen_op_wrapper_py(name, + out=None, + hidden=None, + visibility=None, + deps=[], + require_shape_functions=False, + hidden_file=None, generated_target_name=None): # Construct a cc_binary containing the specified ops. 
tool_name = "gen_" + name + "_py_wrappers_cc" if not deps: - deps = ["//tensorflow/core:" + name + "_op_lib"] + deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))] native.cc_binary( - name = tool_name, - linkopts = ["-lm"], - copts = tf_copts(), - linkstatic = 1, # Faster to link this one-time-use binary dynamically - deps = (["//tensorflow/core:framework", - "//tensorflow/python:python_op_gen_main"] + deps), - visibility = ["//tensorflow:internal"], - ) + name=tool_name, + linkopts=["-lm"], + copts=tf_copts(), + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps=([ + clean_dep("//tensorflow/core:framework"), + clean_dep("//tensorflow/python:python_op_gen_main") + ] + deps), + visibility=[clean_dep("//tensorflow:internal")],) # Invoke the previous cc_binary to generate a python file. if not out: @@ -291,8 +352,8 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[], name=name + "_pygenrule", outs=[out], tools=[tool_name], - cmd=("$(location " + tool_name + ") " + ",".join(hidden) - + " " + ("1" if require_shape_functions else "0") + " > $@")) + cmd=("$(location " + tool_name + ") " + ",".join(hidden) + " " + + ("1" if require_shape_functions else "0") + " > $@")) elif hidden_file: # `hidden_file` is file containing a list of op names to be hidden in the # generated module. @@ -301,77 +362,143 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[], outs=[out], srcs=[hidden_file], tools=[tool_name], - cmd=("$(location " + tool_name + ") @$(location " - + hidden_file + ") " + ("1" if require_shape_functions else "0") - + " > $@")) + cmd=("$(location " + tool_name + ") @$(location " + hidden_file + ") " + + ("1" if require_shape_functions else "0") + " > $@")) else: # No ops should be hidden in the generated module. native.genrule( name=name + "_pygenrule", outs=[out], tools=[tool_name], - cmd=("$(location " + tool_name + ") " - + ("1" if require_shape_functions else "0") + " > $@")) + cmd=("$(location " + tool_name + ") " + + ("1" if require_shape_functions else "0") + " > $@")) # Make a py_library out of the generated python file. if not generated_target_name: generated_target_name = name - native.py_library(name=generated_target_name, - srcs=[out], - srcs_version="PY2AND3", - visibility=visibility, - deps=[ - "//tensorflow/python:framework_for_generated_wrappers", - ],) + native.py_library( + name=generated_target_name, + srcs=[out], + srcs_version="PY2AND3", + visibility=visibility, + deps=[ + clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"), + ],) + # Define a bazel macro that creates cc_test for tensorflow. # TODO(opensource): we need to enable this to work around the hidden symbol # __cudaRegisterFatBinary error. Need more investigations. 
-def tf_cc_test(name, srcs, deps, linkstatic=0, tags=[], data=[], size="medium", - suffix="", args=None, linkopts=[]): - native.cc_test(name="%s%s" % (name, suffix), - srcs=srcs, - size=size, - args=args, - copts=tf_copts(), - data=data, - deps=deps, - linkopts=["-lpthread", "-lm"] + linkopts, - linkstatic=linkstatic, - tags=tags) +def tf_cc_test(name, + srcs, + deps, + linkstatic=0, + tags=[], + data=[], + size="medium", + suffix="", + args=None, + linkopts=[]): + native.cc_test( + name="%s%s" % (name, suffix), + srcs=srcs, + size=size, + args=args, + copts=tf_copts(), + data=data, + deps=deps, + linkopts=["-lpthread", "-lm"] + linkopts, + linkstatic=linkstatic, + tags=tags) + # Part of the testing workflow requires a distinguishable name for the build # rules that involve a GPU, even if otherwise identical to the base rule. -def tf_cc_test_gpu(name, srcs, deps, linkstatic=0, tags=[], data=[], - size="medium", suffix="", args=None): - tf_cc_test(name, srcs, deps, linkstatic=linkstatic, tags=tags, data=data, - size=size, suffix=suffix, args=args) +def tf_cc_test_gpu(name, + srcs, + deps, + linkstatic=0, + tags=[], + data=[], + size="medium", + suffix="", + args=None): + tf_cc_test( + name, + srcs, + deps, + linkstatic=linkstatic, + tags=tags, + data=data, + size=size, + suffix=suffix, + args=args) -def tf_cuda_cc_test(name, srcs=[], deps=[], tags=[], data=[], size="medium", - linkstatic=0, args=[], linkopts=[]): - tf_cc_test(name=name, - srcs=srcs, - deps=deps, - tags=tags + ["manual"], - data=data, - size=size, - linkstatic=linkstatic, - linkopts=linkopts, - args=args) - tf_cc_test(name=name, - srcs=srcs, - suffix="_gpu", - deps=deps + if_cuda(["//tensorflow/core:gpu_runtime"]), - linkstatic=if_cuda(1, 0), - tags=tags + tf_cuda_tests_tags(), - data=data, - size=size, - linkopts=linkopts, - args=args) + +def tf_cuda_cc_test(name, + srcs=[], + deps=[], + tags=[], + data=[], + size="medium", + linkstatic=0, + args=[], + linkopts=[]): + tf_cc_test( + name=name, + srcs=srcs, + deps=deps, + tags=tags + ["manual"], + data=data, + size=size, + linkstatic=linkstatic, + linkopts=linkopts, + args=args) + tf_cc_test( + name=name, + srcs=srcs, + suffix="_gpu", + deps=deps + if_cuda([clean_dep("//tensorflow/core:gpu_runtime")]), + linkstatic=if_cuda(1, 0), + tags=tags + tf_cuda_tests_tags(), + data=data, + size=size, + linkopts=linkopts, + args=args) + +def tf_cuda_only_cc_test(name, + srcs=[], + deps=[], + tags=[], + data=[], + size="medium", + linkstatic=0, + args=[], + linkopts=[]): + native.cc_test( + name="%s%s" % (name, "_gpu"), + srcs=srcs, + size=size, + args=args, + copts= _cuda_copts() + tf_copts(), + data=data, + deps=deps + if_cuda([ + clean_dep("//tensorflow/core:cuda"), + clean_dep("//tensorflow/core:gpu_lib"), + ]), + linkopts=["-lpthread", "-lm"] + linkopts, + linkstatic=linkstatic, + tags=tags) # Create a cc_test for each of the tensorflow tests listed in "tests" -def tf_cc_tests(srcs, deps, name='', linkstatic=0, tags=[], size="medium", - args=None, linkopts=[]): +def tf_cc_tests(srcs, + deps, + name="", + linkstatic=0, + tags=[], + size="medium", + args=None, + linkopts=[]): for src in srcs: tf_cc_test( name=src_to_test_name(src), @@ -383,17 +510,35 @@ def tf_cc_tests(srcs, deps, name='', linkstatic=0, tags=[], size="medium", args=args, linkopts=linkopts) -def tf_cc_test_mkl(srcs, deps, name='', linkstatic=0, tags=[], size="medium", - args=None): - tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args) -def tf_cc_tests_gpu(srcs, deps, name='', linkstatic=0, tags=[], 
size="medium", +def tf_cc_test_mkl(srcs, + deps, + name="", + linkstatic=0, + tags=[], + size="medium", + args=None): + if_mkl(tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)) + + +def tf_cc_tests_gpu(srcs, + deps, + name="", + linkstatic=0, + tags=[], + size="medium", args=None): tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args) -def tf_cuda_cc_tests(srcs, deps, name='', tags=[], size="medium", linkstatic=0, - args=None, linkopts=[]): +def tf_cuda_cc_tests(srcs, + deps, + name="", + tags=[], + size="medium", + linkstatic=0, + args=None, + linkopts=[]): for src in srcs: tf_cuda_cc_test( name=src_to_test_name(src), @@ -405,48 +550,52 @@ def tf_cuda_cc_tests(srcs, deps, name='', tags=[], size="medium", linkstatic=0, args=args, linkopts=linkopts) + def _cuda_copts(): - """Gets the appropriate set of copts for (maybe) CUDA compilation. + """Gets the appropriate set of copts for (maybe) CUDA compilation. If we're doing CUDA compilation, returns copts for our particular CUDA compiler. If we're not doing CUDA compilation, returns an empty list. """ - return cuda_default_copts() + select({ - "//conditions:default": [], - "@local_config_cuda//cuda:using_nvcc": ( - [ - "-nvcc_options=relaxed-constexpr", - "-nvcc_options=ftz=true", - ] - ), - "@local_config_cuda//cuda:using_clang": ( - [ - "-fcuda-flush-denormals-to-zero", - ] - ), - }) + return cuda_default_copts() + select({ + "//conditions:default": [], + "@local_config_cuda//cuda:using_nvcc": ([ + "-nvcc_options=relaxed-constexpr", + "-nvcc_options=ftz=true", + ]), + "@local_config_cuda//cuda:using_clang": ([ + "-fcuda-flush-denormals-to-zero", + ]), + }) + # Build defs for TensorFlow kernels + # When this target is built using --config=cuda, a cc_library is built # that passes -DGOOGLE_CUDA=1 and '-x cuda', linking in additional # libraries needed by GPU kernels. -def tf_gpu_kernel_library(srcs, copts=[], cuda_copts=[], deps=[], hdrs=[], +def tf_gpu_kernel_library(srcs, + copts=[], + cuda_copts=[], + deps=[], + hdrs=[], **kwargs): copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts() native.cc_library( - srcs = srcs, - hdrs = hdrs, - copts = copts, - deps = deps + if_cuda([ - "//tensorflow/core:cuda", - "//tensorflow/core:gpu_lib", + srcs=srcs, + hdrs=hdrs, + copts=copts, + deps=deps + if_cuda([ + clean_dep("//tensorflow/core:cuda"), + clean_dep("//tensorflow/core:gpu_lib"), ]), alwayslink=1, **kwargs) + def tf_cuda_library(deps=None, cuda_deps=None, copts=None, **kwargs): """Generate a cc_library with a conditional set of CUDA dependencies. @@ -471,15 +620,23 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=None, **kwargs): copts = [] native.cc_library( - deps = deps + if_cuda(cuda_deps + [ - "//tensorflow/core:cuda", + deps=deps + if_cuda(cuda_deps + [ + clean_dep("//tensorflow/core:cuda"), "@local_config_cuda//cuda:cuda_headers" ]), - copts = copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]), + copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]), **kwargs) -def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None, - deps=None, alwayslink=1, copts=tf_copts(), **kwargs): + +def tf_kernel_library(name, + prefix=None, + srcs=None, + gpu_srcs=None, + hdrs=None, + deps=None, + alwayslink=1, + copts=tf_copts(), + **kwargs): """A rule to build a TensorFlow OpKernel. May either specify srcs/hdrs or prefix. 
Similar to tf_cuda_library, @@ -509,37 +666,59 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None, deps = [] if prefix: - if native.glob([prefix + "*.cu.cc"], exclude = ["*test*"]): + if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]): if not gpu_srcs: gpu_srcs = [] - gpu_srcs = gpu_srcs + native.glob([prefix + "*.cu.cc", prefix + "*.h"], - exclude = ["*test*"]) - srcs = srcs + native.glob([prefix + "*.cc"], - exclude = ["*test*", "*.cu.cc"]) - hdrs = hdrs + native.glob([prefix + "*.h"], exclude = ["*test*", "*.cu.h"]) + gpu_srcs = gpu_srcs + native.glob( + [prefix + "*.cu.cc", prefix + "*.h"], exclude=[prefix + "*test*"]) + srcs = srcs + native.glob( + [prefix + "*.cc"], exclude=[prefix + "*test*", prefix + "*.cu.cc"]) + hdrs = hdrs + native.glob( + [prefix + "*.h"], exclude=[prefix + "*test*", prefix + "*.cu.h"]) - cuda_deps = ["//tensorflow/core:gpu_lib"] + cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")] if gpu_srcs: for gpu_src in gpu_srcs: if gpu_src.endswith(".cc") and not gpu_src.endswith(".cu.cc"): - fail("{} not allowed in gpu_srcs. .cc sources must end with .cu.cc".format(gpu_src)) + fail("{} not allowed in gpu_srcs. .cc sources must end with .cu.cc". + format(gpu_src)) tf_gpu_kernel_library( - name = name + "_gpu", - srcs = gpu_srcs, - deps = deps, - **kwargs) + name=name + "_gpu", srcs=gpu_srcs, deps=deps, **kwargs) cuda_deps.extend([":" + name + "_gpu"]) tf_cuda_library( - name = name, - srcs = srcs, - hdrs = hdrs, - copts = copts, - cuda_deps = cuda_deps, - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - alwayslink = alwayslink, - deps = deps, + name=name, + srcs=srcs, + hdrs=hdrs, + copts=copts, + cuda_deps=cuda_deps, + linkstatic=1, # Needed since alwayslink is broken in bazel b/27630669 + alwayslink=alwayslink, + deps=deps, **kwargs) + +def tf_mkl_kernel_library(name, + prefix=None, + srcs=None, + gpu_srcs=None, + hdrs=None, + deps=None, + alwayslink=1, + copts=tf_copts(), + **kwargs): + if_mkl( + tf_kernel_library( + name, + prefix=prefix, + srcs=srcs, + gpu_srcs=gpu_srcs, + hdrs=hdrs, + deps=deps, + alwayslink=alwayslink, + copts=copts, + **kwargs)) + + # Bazel rules for building swig files. 
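Before the SWIG rules that the header comment above opens, a usage sketch for `tf_kernel_library` as reformatted in this hunk (the new `tf_mkl_kernel_library` wraps the same invocation in `if_mkl`); the `relu_op` file set is an assumption:

```python
# Hypothetical kernel target; relu_op.cc/.h and relu_op.cu.cc are assumed to exist.
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")

tf_kernel_library(
    name = "relu_op",
    # prefix globs relu_op*.cc and relu_op*.h for the CPU build, and
    # relu_op*.cu.cc for the GPU build, excluding relu_op*test* files
    # per the tightened exclude patterns in this change.
    prefix = "relu_op",
    deps = ["//tensorflow/core:framework"],
)
```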
def _py_wrap_cc_impl(ctx): srcs = ctx.files.srcs @@ -555,59 +734,61 @@ def _py_wrap_cc_impl(ctx): inputs += ctx.files.toolchain_deps swig_include_dirs = set(_get_repository_roots(ctx, inputs)) swig_include_dirs += sorted([f.dirname for f in ctx.files._swiglib]) - args = ["-c++", - "-python", - "-module", module_name, - "-o", ctx.outputs.cc_out.path, - "-outdir", ctx.outputs.py_out.dirname] + args = [ + "-c++", "-python", "-module", module_name, "-o", ctx.outputs.cc_out.path, + "-outdir", ctx.outputs.py_out.dirname + ] args += ["-l" + f.path for f in ctx.files.swig_includes] args += ["-I" + i for i in swig_include_dirs] args += [src.path] - outputs = [ctx.outputs.cc_out, - ctx.outputs.py_out] - ctx.action(executable=ctx.executable._swig, - arguments=args, - inputs=list(inputs), - outputs=outputs, - mnemonic="PythonSwig", - progress_message="SWIGing " + src.path) + outputs = [ctx.outputs.cc_out, ctx.outputs.py_out] + ctx.action( + executable=ctx.executable._swig, + arguments=args, + inputs=list(inputs), + outputs=outputs, + mnemonic="PythonSwig", + progress_message="SWIGing " + src.path) return struct(files=set(outputs)) + _py_wrap_cc = rule( - attrs = { - "srcs": attr.label_list( - mandatory = True, - allow_files = True, - ), - "swig_includes": attr.label_list( - cfg = "data", - allow_files = True, - ), - "deps": attr.label_list( - allow_files = True, - providers = ["cc"], - ), - "toolchain_deps": attr.label_list( - allow_files = True, - ), - "module_name": attr.string(mandatory = True), - "py_module_name": attr.string(mandatory = True), - "_swig": attr.label( - default = Label("@swig//:swig"), - executable = True, - cfg = "host", - ), - "_swiglib": attr.label( - default = Label("@swig//:templates"), - allow_files = True, - ), + attrs={ + "srcs": + attr.label_list( + mandatory=True, + allow_files=True,), + "swig_includes": + attr.label_list( + cfg="data", + allow_files=True,), + "deps": + attr.label_list( + allow_files=True, + providers=["cc"],), + "toolchain_deps": + attr.label_list( + allow_files=True,), + "module_name": + attr.string(mandatory=True), + "py_module_name": + attr.string(mandatory=True), + "_swig": + attr.label( + default=Label("@swig//:swig"), + executable=True, + cfg="host",), + "_swiglib": + attr.label( + default=Label("@swig//:templates"), + allow_files=True,), }, - outputs = { + outputs={ "cc_out": "%{module_name}.cc", "py_out": "%{py_module_name}.py", }, - implementation = _py_wrap_cc_impl, -) + implementation=_py_wrap_cc_impl,) + def _get_repository_roots(ctx, files): """Returns abnormal root directories under which files reside. @@ -638,6 +819,7 @@ def _get_repository_roots(ctx, files): result[root] -= 1 return [k for v, k in sorted([(v, k) for k, v in result.items()])] + # Bazel rule for collecting the header files that a target depends on. 
def _transitive_hdrs_impl(ctx): outputs = set() @@ -645,38 +827,36 @@ def _transitive_hdrs_impl(ctx): outputs += dep.cc.transitive_headers return struct(files=outputs) + _transitive_hdrs = rule( - attrs = { + attrs={ "deps": attr.label_list( - allow_files = True, - providers = ["cc"], - ), + allow_files=True, + providers=["cc"],), }, - implementation = _transitive_hdrs_impl, -) + implementation=_transitive_hdrs_impl,) + def transitive_hdrs(name, deps=[], **kwargs): - _transitive_hdrs(name=name + "_gather", - deps=deps) - native.filegroup(name=name, - srcs=[":" + name + "_gather"]) + _transitive_hdrs(name=name + "_gather", deps=deps) + native.filegroup(name=name, srcs=[":" + name + "_gather"]) + # Create a header only library that includes all the headers exported by # the libraries in deps. def cc_header_only_library(name, deps=[], **kwargs): - _transitive_hdrs(name=name + "_gather", - deps=deps) - native.cc_library(name=name, - hdrs=[":" + name + "_gather"], - **kwargs) + _transitive_hdrs(name=name + "_gather", deps=deps) + native.cc_library(name=name, hdrs=[":" + name + "_gather"], **kwargs) + def tf_custom_op_library_additional_deps(): return [ - "@protobuf//:protobuf", - "//third_party/eigen3", - "//tensorflow/core:framework_headers_lib", + "@protobuf//:protobuf_headers", + clean_dep("//third_party/eigen3"), + clean_dep("//tensorflow/core:framework_headers_lib"), ] + # Traverse the dependency graph along the "deps" attribute of the # target and return a struct with one field called 'tf_collected_deps'. # tf_collected_deps will be the union of the deps of the current target @@ -690,14 +870,16 @@ def _collect_deps_aspect_impl(target, ctx): alldeps = alldeps | dep.tf_collected_deps return struct(tf_collected_deps=alldeps) + collect_deps_aspect = aspect( - implementation=_collect_deps_aspect_impl, - attr_aspects=["deps"]) + implementation=_collect_deps_aspect_impl, attr_aspects=["deps"]) + def _dep_label(dep): label = dep.label return label.package + ":" + label.name + # This rule checks that the transitive dependencies of targets listed # in the 'deps' attribute don't depend on the targets listed in # the 'disallowed_deps' attribute. @@ -709,137 +891,177 @@ def _check_deps_impl(ctx): for dep in input_dep.tf_collected_deps: for disallowed_dep in disallowed_deps: if dep == disallowed_dep.label: - fail(_dep_label(input_dep) + " cannot depend on " + - _dep_label(disallowed_dep)) + fail( + _dep_label(input_dep) + " cannot depend on " + _dep_label( + disallowed_dep)) return struct() + check_deps = rule( _check_deps_impl, - attrs = { - "deps": attr.label_list( - aspects=[collect_deps_aspect], - mandatory = True, - allow_files = True - ), - "disallowed_deps": attr.label_list( - mandatory = True, - allow_files = True - )}, -) + attrs={ + "deps": + attr.label_list( + aspects=[collect_deps_aspect], mandatory=True, + allow_files=True), + "disallowed_deps": + attr.label_list(mandatory=True, allow_files=True) + },) + # Helper to build a dynamic library (.so) from the sources containing # implementations of custom ops and kernels. 
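The comment above introduces `tf_custom_op_library`, defined immediately below. A typical invocation might look like the following sketch, in the spirit of the adding-an-op documentation; the `zero_out` names are illustrative, and only the header-only deps from `tf_custom_op_library_additional_deps()` are linked, since `check_deps` rejects `//tensorflow/core:framework` and `//tensorflow/core:lib`:

```python
# Hypothetical custom-op shared library; zero_out_op.* sources are assumptions.
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",              # basename "zero_out" names the _gpu sub-library
    srcs = ["zero_out_op.cc"],
    gpu_srcs = ["zero_out_op.cu.cc"],  # compiled with _cuda_copts() under --config=cuda
)
```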
def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]): cuda_deps = [ - "//tensorflow/core:stream_executor_headers_lib", + clean_dep("//tensorflow/core:stream_executor_headers_lib"), "@local_config_cuda//cuda:cudart_static", ] deps = deps + tf_custom_op_library_additional_deps() if gpu_srcs: basename = name.split(".")[0] native.cc_library( - name = basename + "_gpu", - srcs = gpu_srcs, - copts = _cuda_copts(), - deps = deps + if_cuda(cuda_deps)) + name=basename + "_gpu", + srcs=gpu_srcs, + copts=_cuda_copts(), + deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) - check_deps(name=name+"_check_deps", - deps=deps + if_cuda(cuda_deps), - disallowed_deps=["//tensorflow/core:framework", - "//tensorflow/core:lib"]) + check_deps( + name=name + "_check_deps", + deps=deps + if_cuda(cuda_deps), + disallowed_deps=[ + clean_dep("//tensorflow/core:framework"), + clean_dep("//tensorflow/core:lib") + ]) + + native.cc_binary( + name=name, + srcs=srcs, + deps=deps + if_cuda(cuda_deps), + data=[name + "_check_deps"], + copts=tf_copts(), + linkshared=1, + linkopts=select({ + "//conditions:default": [ + "-lm", + ], + clean_dep("//tensorflow:darwin"): [], + }),) + + +def tf_custom_op_py_library(name, + srcs=[], + dso=[], + kernels=[], + srcs_version="PY2AND3", + visibility=None, + deps=[]): + kernels = kernels # unused argument + native.py_library( + name=name, + data=dso, + srcs=srcs, + srcs_version=srcs_version, + visibility=visibility, + deps=deps,) - native.cc_binary(name=name, - srcs=srcs, - deps=deps + if_cuda(cuda_deps), - data=[name + "_check_deps"], - copts=tf_copts(), - linkshared=1, - linkopts = select({ - "//conditions:default": [ - "-lm", - ], - "//tensorflow:darwin": [], - }), - ) def tf_extension_linkopts(): return [] # No extension link opts + def tf_extension_copts(): return [] # No extension c opts -def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs): + +def tf_py_wrap_cc(name, + srcs, + swig_includes=[], + deps=[], + copts=[], + **kwargs): module_name = name.split("/")[-1] # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so # and use that as the name for the rule producing the .so file. 
cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"]) - cc_library_pyd_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".pyd"]) + cc_library_pyd_name = "/".join( + name.split("/")[:-1] + ["_" + module_name + ".pyd"]) extra_deps = [] - _py_wrap_cc(name=name + "_py_wrap", - srcs=srcs, - swig_includes=swig_includes, - deps=deps + extra_deps, - toolchain_deps=["//tools/defaults:crosstool"], - module_name=module_name, - py_module_name=name) + _py_wrap_cc( + name=name + "_py_wrap", + srcs=srcs, + swig_includes=swig_includes, + deps=deps + extra_deps, + toolchain_deps=["//tools/defaults:crosstool"], + module_name=module_name, + py_module_name=name) extra_linkopts = select({ "@local_config_cuda//cuda:darwin": [ "-Wl,-exported_symbols_list", - "//tensorflow:tf_exported_symbols.lds" - ], - "//tensorflow:windows": [ + clean_dep("//tensorflow:tf_exported_symbols.lds") ], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,--version-script", - "//tensorflow:tf_version_script.lds" - ]}) + clean_dep("//tensorflow:tf_version_script.lds") + ] + }) extra_deps += select({ "@local_config_cuda//cuda:darwin": [ - "//tensorflow:tf_exported_symbols.lds" - ], - "//tensorflow:windows": [ + clean_dep("//tensorflow:tf_exported_symbols.lds") ], + clean_dep("//tensorflow:windows"): [], + clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ - "//tensorflow:tf_version_script.lds" + clean_dep("//tensorflow:tf_version_script.lds") ] }) native.cc_binary( name=cc_library_name, srcs=[module_name + ".cc"], - copts=(copts + ["-Wno-self-assign", - "-Wno-sign-compare", - "-Wno-write-strings"] - + tf_extension_copts()), + copts=(copts + [ + "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings" + ] + tf_extension_copts()), linkopts=tf_extension_linkopts() + extra_linkopts, linkstatic=1, linkshared=1, deps=deps + extra_deps) native.genrule( - name = "gen_" + cc_library_pyd_name, - srcs = [":" + cc_library_name], - outs = [cc_library_pyd_name], - cmd = "cp $< $@", - ) - native.py_library(name=name, - srcs=[":" + name + ".py"], - srcs_version="PY2AND3", - data=select({ - "//tensorflow:windows": [":" + cc_library_pyd_name], - "//conditions:default": [":" + cc_library_name], - })) + name="gen_" + cc_library_pyd_name, + srcs=[":" + cc_library_name], + outs=[cc_library_pyd_name], + cmd="cp $< $@",) + native.py_library( + name=name, + srcs=[":" + name + ".py"], + srcs_version="PY2AND3", + data=select({ + clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name], + "//conditions:default": [":" + cc_library_name], + })) + def py_test(deps=[], **kwargs): native.py_test( deps=select({ - "//conditions:default" : deps, - "//tensorflow:no_tensorflow_py_deps" : [] + "//conditions:default": deps, + clean_dep("//tensorflow:no_tensorflow_py_deps"): [] }), **kwargs) -def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], - tags=[], shard_count=1, additional_deps=[], flaky=0, + +def tf_py_test(name, + srcs, + size="medium", + data=[], + main=None, + args=[], + tags=[], + shard_count=1, + additional_deps=[], + flaky=0, xla_enabled=False): if xla_enabled: additional_deps += tf_additional_xla_deps_py() @@ -850,50 +1072,71 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], main=main, args=args, tags=tags, - visibility=["//tensorflow:internal"], + visibility=[clean_dep("//tensorflow:internal")], shard_count=shard_count, data=data, deps=select({ - "//conditions:default" : [ - 
"//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:gradient_checker", + "//conditions:default": [ + clean_dep("//tensorflow/python:extra_py_tests_deps"), + clean_dep("//tensorflow/python:gradient_checker"), ] + additional_deps, - "//tensorflow:no_tensorflow_py_deps" : [] + clean_dep("//tensorflow:no_tensorflow_py_deps"): [] }), flaky=flaky, srcs_version="PY2AND3") -def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0, + +def cuda_py_test(name, + srcs, + size="medium", + data=[], + main=None, + args=[], + shard_count=1, + additional_deps=[], + tags=[], + flaky=0, xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() - tf_py_test(name=name, - size=size, - srcs=srcs, - data=data, - main=main, - args=args, - tags=test_tags, - shard_count=shard_count, - additional_deps=additional_deps, - flaky=flaky, - xla_enabled=xla_enabled) + tf_py_test( + name=name, + size=size, + srcs=srcs, + data=data, + main=main, + args=args, + tags=test_tags, + shard_count=shard_count, + additional_deps=additional_deps, + flaky=flaky, + xla_enabled=xla_enabled) -def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0, + +def sycl_py_test(name, + srcs, + size="medium", + data=[], + main=None, + args=[], + shard_count=1, + additional_deps=[], + tags=[], + flaky=0, xla_enabled=False): - test_tags = tags + tf_sycl_tests_tags() - tf_py_test(name=name, - size=size, - srcs=srcs, - data=data, - main=main, - args=args, - tags=test_tags, - shard_count=shard_count, - additional_deps=additional_deps, - flaky=flaky, - xla_enabled=xla_enabled) + test_tags = tags + tf_sycl_tests_tags() + tf_py_test( + name=name, + size=size, + srcs=srcs, + data=data, + main=main, + args=args, + tags=test_tags, + shard_count=shard_count, + additional_deps=additional_deps, + flaky=flaky, + xla_enabled=xla_enabled) + def py_tests(name, srcs, @@ -908,22 +1151,39 @@ def py_tests(name, test_name = src.split("/")[-1].split(".")[0] if prefix: test_name = "%s_%s" % (prefix, test_name) - tf_py_test(name=test_name, - size=size, - srcs=[src], - main=src, - tags=tags, - shard_count=shard_count, - data=data, - additional_deps=additional_deps, - xla_enabled=xla_enabled) + tf_py_test( + name=test_name, + size=size, + srcs=[src], + main=src, + tags=tags, + shard_count=shard_count, + data=data, + additional_deps=additional_deps, + xla_enabled=xla_enabled) -def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[], - shard_count=1, tags=[], prefix="", xla_enabled=False): + +def cuda_py_tests(name, + srcs, + size="medium", + additional_deps=[], + data=[], + shard_count=1, + tags=[], + prefix="", + xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() - py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps, - data=data, tags=test_tags, shard_count=shard_count,prefix=prefix, - xla_enabled=xla_enabled) + py_tests( + name=name, + size=size, + srcs=srcs, + additional_deps=additional_deps, + data=data, + tags=test_tags, + shard_count=shard_count, + prefix=prefix, + xla_enabled=xla_enabled) + # Creates a genrule named for running tools/proto_text's generator to # make the proto_text functions, for the protos passed in . @@ -931,40 +1191,46 @@ def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[], # Return a struct with fields (hdrs, srcs) containing the names of the # generated files. 
def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs): - out_hdrs = ([p.replace(".proto", ".pb_text.h") for p in srcs] + - [p.replace(".proto", ".pb_text-impl.h") for p in srcs]) + out_hdrs = ( + [p.replace(".proto", ".pb_text.h") + for p in srcs] + [p.replace(".proto", ".pb_text-impl.h") for p in srcs]) out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs] native.genrule( - name = name, - srcs = srcs + ["//tensorflow/tools/proto_text:placeholder.txt"], - outs = out_hdrs + out_srcs, - cmd = "$(location //tensorflow/tools/proto_text:gen_proto_text_functions) " + - "$(@D) " + srcs_relative_dir + " $(SRCS)", - tools = ["//tensorflow/tools/proto_text:gen_proto_text_functions"], - ) + name=name, + srcs=srcs + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")], + outs=out_hdrs + out_srcs, + cmd= + "$(location //tensorflow/tools/proto_text:gen_proto_text_functions) " + + "$(@D) " + srcs_relative_dir + " $(SRCS)", + tools=[ + clean_dep("//tensorflow/tools/proto_text:gen_proto_text_functions") + ],) return struct(hdrs=out_hdrs, srcs=out_srcs) + def tf_genrule_cmd_append_to_srcs(to_append): - return ("cat $(SRCS) > $(@) && " + - "echo >> $(@) && " + - "echo " + to_append + " >> $(@)") + return ("cat $(SRCS) > $(@) && " + "echo >> $(@) && " + "echo " + to_append + + " >> $(@)") def tf_version_info_genrule(): native.genrule( - name = "version_info_gen", - srcs = [ - "//tensorflow/tools/git:gen/spec.json", - "//tensorflow/tools/git:gen/head", - "//tensorflow/tools/git:gen/branch_ref", + name="version_info_gen", + srcs=[ + clean_dep("//tensorflow/tools/git:gen/spec.json"), + clean_dep("//tensorflow/tools/git:gen/head"), + clean_dep("//tensorflow/tools/git:gen/branch_ref"), ], - outs = ["util/version_info.cc"], - cmd = "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\"", - local = 1, - tools = ["//tensorflow/tools/git:gen_git_source.py"], - ) + outs=["util/version_info.cc"], + cmd= + "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\"", + local=1, + tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],) -def cc_library_with_android_deps(deps, android_deps=[], - common_deps=[], **kwargs): + +def cc_library_with_android_deps(deps, + android_deps=[], + common_deps=[], + **kwargs): deps = if_not_android(deps) + if_android(android_deps) + common_deps native.cc_library(deps=deps, **kwargs) diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds index cb81e89922c..1f4d900ec2b 100644 --- a/tensorflow/tf_exported_symbols.lds +++ b/tensorflow/tf_exported_symbols.lds @@ -1,3 +1,4 @@ *tensorflow* *perftools*gputools* *tf_* +TF_* diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds index 8c8c8be5a93..b368f7cf21d 100644 --- a/tensorflow/tf_version_script.lds +++ b/tensorflow/tf_version_script.lds @@ -2,6 +2,7 @@ tensorflow { global: *tensorflow*; *perftools*gputools*; + TF_*; local: *; }; diff --git a/tensorflow/tools/api/golden/BUILD b/tensorflow/tools/api/golden/BUILD new file mode 100644 index 00000000000..08436396a6c --- /dev/null +++ b/tensorflow/tools/api/golden/BUILD @@ -0,0 +1,24 @@ +# TensorFlow API backwards compatibility test goldens. 
+ +package( + default_visibility = ["//tensorflow/tools/api:__subpackages__"], +) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "api_golden", + srcs = glob(["*.pbtxt"]), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/api/golden/tensorflow.-aggregation-method.pbtxt b/tensorflow/tools/api/golden/tensorflow.-aggregation-method.pbtxt new file mode 100644 index 00000000000..f79029d3fe0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-aggregation-method.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.AggregationMethod" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ADD_N" + mtype: "" + } + member { + name: "DEFAULT" + mtype: "" + } + member { + name: "EXPERIMENTAL_ACCUMULATE_N" + mtype: "" + } + member { + name: "EXPERIMENTAL_TREE" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt new file mode 100644 index 00000000000..0fb1aaba283 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt @@ -0,0 +1,108 @@ +path: "tensorflow.AttrValue.ListValue" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "B_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FUNC_FIELD_NUMBER" + mtype: "" + } + member { + name: "F_FIELD_NUMBER" + mtype: "" + } + member { + name: "I_FIELD_NUMBER" + mtype: "" + } + member { + name: "SHAPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "S_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_FIELD_NUMBER" + mtype: "" + } + member { + name: "TYPE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt new file mode 100644 index 00000000000..e7a3a1f02fa --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt @@ -0,0 +1,120 @@ +path: "tensorflow.AttrValue" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "B_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FUNC_FIELD_NUMBER" + mtype: "" + } + member { + name: "F_FIELD_NUMBER" + mtype: "" + } + member { + name: "I_FIELD_NUMBER" + mtype: "" + } + member 
{ + name: "LIST_FIELD_NUMBER" + mtype: "" + } + member { + name: "ListValue" + mtype: "" + } + member { + name: "PLACEHOLDER_FIELD_NUMBER" + mtype: "" + } + member { + name: "SHAPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "S_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_FIELD_NUMBER" + mtype: "" + } + member { + name: "TYPE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator-base.pbtxt b/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator-base.pbtxt new file mode 100644 index 00000000000..c9a32c16b34 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator-base.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.ConditionalAccumulatorBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'accumulator_ref\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator.pbtxt new file mode 100644 index 00000000000..d23b3bd0cae --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.ConditionalAccumulator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + } + member_method { + name: "apply_grad" + argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt new file mode 100644 index 00000000000..29bb3be35cb --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.ConfigProto.DeviceCountEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt new file mode 100644 index 00000000000..da6af3919e9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -0,0 +1,136 @@ +path: "tensorflow.ConfigProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" + mtype: "" + } + member { + name: "CLUSTER_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "DEVICE_COUNT_FIELD_NUMBER" + mtype: "" + } + member { + name: "DEVICE_FILTERS_FIELD_NUMBER" + mtype: "" + } + member { + name: "DeviceCountEntry" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "GPU_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "GRAPH_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "INTER_OP_PARALLELISM_THREADS_FIELD_NUMBER" + mtype: "" + } + member { + name: "INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER" + mtype: "" + } + member { + name: "LOG_DEVICE_PLACEMENT_FIELD_NUMBER" + mtype: "" + } + member { + name: "OPERATION_TIMEOUT_IN_MS_FIELD_NUMBER" + mtype: "" + } + member { + name: "PLACEMENT_PERIOD_FIELD_NUMBER" + mtype: "" + } + member { + name: "RPC_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "SESSION_INTER_OP_THREAD_POOL_FIELD_NUMBER" + mtype: "" + } + member { + name: "USE_PER_SESSION_THREADS_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + 
member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt b/tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt new file mode 100644 index 00000000000..0b5b88bba80 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt @@ -0,0 +1,77 @@ +path: "tensorflow.DType" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "as_datatype_enum" + mtype: "" + } + member { + name: "as_numpy_dtype" + mtype: "" + } + member { + name: "base_dtype" + mtype: "" + } + member { + name: "is_bool" + mtype: "" + } + member { + name: "is_complex" + mtype: "" + } + member { + name: "is_floating" + mtype: "" + } + member { + name: "is_integer" + mtype: "" + } + member { + name: "is_numpy_compatible" + mtype: "" + } + member { + name: "is_quantized" + mtype: "" + } + member { + name: "is_unsigned" + mtype: "" + } + member { + name: "limits" + mtype: "" + } + member { + name: "max" + mtype: "" + } + member { + name: "min" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "real_dtype" + mtype: "" + } + member { + name: "size" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt new file mode 100644 index 00000000000..92e535c3414 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.DeviceSpec" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "job" + mtype: "" + } + member { + name: "replica" + mtype: "" + } + member { + name: "task" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_string" + argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_from" + argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "parse_from_string" + argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_string" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt b/tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt new file mode 100644 index 00000000000..a9ab27719b4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt @@ -0,0 +1,25 @@ +path: 
"tensorflow.Dimension" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "value" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt new file mode 100644 index 00000000000..9bf8c124288 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.Event" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FILE_VERSION_FIELD_NUMBER" + mtype: "" + } + member { + name: "GRAPH_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "LOG_MESSAGE_FIELD_NUMBER" + mtype: "" + } + member { + name: "META_GRAPH_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "SESSION_LOG_FIELD_NUMBER" + mtype: "" + } + member { + name: "STEP_FIELD_NUMBER" + mtype: "" + } + member { + name: "SUMMARY_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAGGED_RUN_METADATA_FIELD_NUMBER" + mtype: "" + } + member { + name: "WALL_TIME_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt new file mode 100644 index 00000000000..72cc5324476 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.FIFOQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'fifo_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, 
keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-len-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-len-feature.pbtxt new file mode 100644 index 00000000000..6933814a7b6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-fixed-len-feature.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.FixedLenFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "default_value" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-len-sequence-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-len-sequence-feature.pbtxt new file mode 100644 index 00000000000..c5387879519 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-fixed-len-sequence-feature.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.FixedLenSequenceFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_missing" + mtype: "" + } + member { + name: "default_value" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt new file mode 100644 index 00000000000..5c77b3dd5cc --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.FixedLengthRecordReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt new file mode 100644 index 00000000000..30f7e4e1165 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt @@ -0,0 +1,108 @@ +path: "tensorflow.GPUOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ALLOCATOR_TYPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "ALLOW_GROWTH_FIELD_NUMBER" + mtype: "" + } + member { + name: "DEFERRED_DELETION_BYTES_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FORCE_GPU_COMPATIBLE_FIELD_NUMBER" + mtype: "" + } + member { + name: "PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER" + mtype: "" + } + member { + name: "POLLING_ACTIVE_DELAY_USECS_FIELD_NUMBER" + mtype: "" + } + member { + name: "POLLING_INACTIVE_DELAY_MSECS_FIELD_NUMBER" + mtype: "" + } + member { + name: "VISIBLE_DEVICE_LIST_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt new file mode 100644 index 00000000000..1495e847cb0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.GraphDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "LIBRARY_FIELD_NUMBER" + mtype: "" + } + member { + name: "NODE_FIELD_NUMBER" + mtype: "" + } + member { + name: "VERSIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "VERSION_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method 
{ + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt new file mode 100644 index 00000000000..ef2cfe3787e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt @@ -0,0 +1,136 @@ +path: "tensorflow.GraphKeys" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ACTIVATIONS" + mtype: "" + } + member { + name: "ASSET_FILEPATHS" + mtype: "" + } + member { + name: "BIASES" + mtype: "" + } + member { + name: "CONCATENATED_VARIABLES" + mtype: "" + } + member { + name: "COND_CONTEXT" + mtype: "" + } + member { + name: "EVAL_STEP" + mtype: "" + } + member { + name: "GLOBAL_STEP" + mtype: "" + } + member { + name: "GLOBAL_VARIABLES" + mtype: "" + } + member { + name: "INIT_OP" + mtype: "" + } + member { + name: "LOCAL_INIT_OP" + mtype: "" + } + member { + name: "LOCAL_RESOURCES" + mtype: "" + } + member { + name: "LOCAL_VARIABLES" + mtype: "" + } + member { + name: "LOSSES" + mtype: "" + } + member { + name: "MODEL_VARIABLES" + mtype: "" + } + member { + name: "MOVING_AVERAGE_VARIABLES" + mtype: "" + } + member { + name: "QUEUE_RUNNERS" + mtype: "" + } + member { + name: "READY_FOR_LOCAL_INIT_OP" + mtype: "" + } + member { + name: "READY_OP" + mtype: "" + } + member { + name: "REGULARIZATION_LOSSES" + mtype: "" + } + member { + name: "RESOURCES" + mtype: "" + } + member { + name: "SAVEABLE_OBJECTS" + mtype: "" + } + member { + name: "SAVERS" + mtype: "" + } + member { + name: "SUMMARIES" + mtype: "" + } + member { + name: "SUMMARY_OP" + mtype: "" + } + member { + name: "TABLE_INITIALIZERS" + mtype: "" + } + member { + name: "TRAINABLE_RESOURCE_VARIABLES" + mtype: "" + } + member { + name: "TRAINABLE_VARIABLES" + mtype: "" + } + member { + name: "TRAIN_OP" + mtype: "" + } + member { + name: "UPDATE_OPS" + mtype: "" + } + member { + name: "VARIABLES" + mtype: "" + } + member { + name: "WEIGHTS" + mtype: "" + } + member { + name: "WHILE_CONTEXT" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt new file mode 100644 index 00000000000..0844f891cad --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.GraphOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "BUILD_COST_MODEL_AFTER_FIELD_NUMBER" + mtype: "" + } + member { + name: "BUILD_COST_MODEL_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ENABLE_BFLOAT16_SENDRECV_FIELD_NUMBER" + mtype: "" + } + member { + name: "ENABLE_RECV_SCHEDULING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: 
"INFER_SHAPES_FIELD_NUMBER" + mtype: "" + } + member { + name: "OPTIMIZER_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "PLACE_PRUNED_GRAPH_FIELD_NUMBER" + mtype: "" + } + member { + name: "REWRITE_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TIMELINE_STEP_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt new file mode 100644 index 00000000000..75361803a39 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt @@ -0,0 +1,137 @@ +path: "tensorflow.Graph" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "building_function" + mtype: "" + } + member { + name: "collections" + mtype: "" + } + member { + name: "finalized" + mtype: "" + } + member { + name: "graph_def_versions" + mtype: "" + } + member { + name: "seed" + mtype: "" + } + member { + name: "version" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_to_collection" + argspec: "args=[\'self\', \'name\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_to_collections" + argspec: "args=[\'self\', \'names\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_graph_def" + argspec: "args=[\'self\', \'from_version\', \'add_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "as_graph_element" + argspec: "args=[\'self\', \'obj\', \'allow_tensor\', \'allow_operation\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "clear_collection" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "colocate_with" + argspec: "args=[\'self\', \'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "container" + argspec: "args=[\'self\', \'container_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "control_dependencies" + argspec: "args=[\'self\', \'control_inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "create_op" + argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " + } + member_method { + name: "device" + argspec: "args=[\'self\', \'device_name_or_function\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_all_collection_keys" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_collection" + argspec: "args=[\'self\', \'name\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection_ref" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_name_scope" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_operation_by_name" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_operations" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor_by_name" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gradient_override_map" + argspec: "args=[\'self\', \'op_type_map\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_feedable" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_fetchable" + argspec: "args=[\'self\', \'tensor_or_op\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "name_scope" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prevent_feeding" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prevent_fetching" + argspec: "args=[\'self\', \'op\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "unique_name" + argspec: "args=[\'self\', \'name\', \'mark_as_used\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt new file mode 100644 index 00000000000..2567d2fe602 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt @@ -0,0 +1,104 @@ +path: "tensorflow.HistogramProto" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "BUCKET_FIELD_NUMBER" + mtype: "" + } + member { + name: "BUCKET_LIMIT_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "MAX_FIELD_NUMBER" + mtype: "" + } + member { + name: "MIN_FIELD_NUMBER" + mtype: "" + } + member { + name: "NUM_FIELD_NUMBER" + mtype: "" + } + member { + name: "SUM_FIELD_NUMBER" + mtype: "" + } + member { + name: "SUM_SQUARES_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } 
+ member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-identity-reader.pbtxt new file mode 100644 index 00000000000..2eda320d636 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-identity-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.IdentityReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-indexed-slices.pbtxt b/tensorflow/tools/api/golden/tensorflow.-indexed-slices.pbtxt new file mode 100644 index 00000000000..fee84d85307 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-indexed-slices.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.IndexedSlices" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'values\', \'indices\', \'dense_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt new file mode 100644 index 00000000000..f5b0bae58d0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.InteractiveSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: 
"graph_def" + mtype: "" + } + member { + name: "sess_str" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target\', \'graph\', \'config\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\'], " + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "list_devices" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_callable" + argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run" + argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run_setup" + argspec: "args=[\'self\', \'fetches\', \'feeds\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt new file mode 100644 index 00000000000..a43c5eb7e30 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.LogMessage" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DEBUGGING" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ERROR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FATAL" + mtype: "" + } + member { + name: "INFO" + mtype: "" + } + member { + name: "LEVEL_FIELD_NUMBER" + mtype: "" + } + member { + name: "Level" + mtype: "" + } + member { + name: "MESSAGE_FIELD_NUMBER" + mtype: "" + } + member { + name: "UNKNOWN" + mtype: "" + } + member { + name: "WARN" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt new file mode 100644 index 00000000000..3572126fbfd --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.MetaGraphDef.CollectionDefEntry" +tf_class { + is_instance: "" + is_instance: 
"" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt new file mode 100644 index 00000000000..ebf49f434ae --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt @@ -0,0 +1,100 @@ +path: "tensorflow.MetaGraphDef.MetaInfoDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ANY_INFO_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "META_GRAPH_VERSION_FIELD_NUMBER" + mtype: "" + } + member { + name: "STRIPPED_OP_LIST_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAGS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSORFLOW_GIT_VERSION_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSORFLOW_VERSION_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt new file mode 100644 index 00000000000..48fccac99d6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.MetaGraphDef.SignatureDefEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + 
member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt new file mode 100644 index 00000000000..3e683a87159 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.MetaGraphDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ASSET_FILE_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "COLLECTION_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "CollectionDefEntry" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "GRAPH_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "META_INFO_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "MetaInfoDef" + mtype: "" + } + member { + name: "SAVER_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "SIGNATURE_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "SignatureDefEntry" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt new file mode 100644 index 00000000000..2750bd780ca --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.NameAttrList.AttrEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + 
member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt new file mode 100644 index 00000000000..d10faf67d02 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.NameAttrList" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ATTR_FIELD_NUMBER" + mtype: "" + } + member { + name: "AttrEntry" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt new file mode 100644 index 00000000000..b1b62d60f1e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.NodeDef.AttrEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: 
"DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt new file mode 100644 index 00000000000..b812b4df2b3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt @@ -0,0 +1,100 @@ +path: "tensorflow.NodeDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ATTR_FIELD_NUMBER" + mtype: "" + } + member { + name: "AttrEntry" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "DEVICE_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "INPUT_FIELD_NUMBER" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "OP_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-op-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.-op-error.pbtxt new file mode 100644 index 00000000000..7e59615534f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-op-error.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.OpError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt new file mode 100644 index 00000000000..64240f70698 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.Operation" +tf_class { + is_instance: "" + is_instance: "" + member { + 
name: "control_inputs" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inputs" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op_def" + mtype: "" + } + member { + name: "outputs" + mtype: "" + } + member { + name: "traceback" + mtype: "" + } + member { + name: "traceback_with_start_lines" + mtype: "" + } + member { + name: "type" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'g\', \'inputs\', \'output_types\', \'control_inputs\', \'input_types\', \'original_op\', \'op_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "colocation_groups" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_attr" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "values" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt new file mode 100644 index 00000000000..5dd1ee47c96 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt @@ -0,0 +1,128 @@ +path: "tensorflow.OptimizerOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DEFAULT" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "DO_COMMON_SUBEXPRESSION_ELIMINATION_FIELD_NUMBER" + mtype: "" + } + member { + name: "DO_CONSTANT_FOLDING_FIELD_NUMBER" + mtype: "" + } + member { + name: "DO_FUNCTION_INLINING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "GLOBAL_JIT_LEVEL_FIELD_NUMBER" + mtype: "" + } + member { + name: "GlobalJitLevel" + mtype: "" + } + member { + name: "L0" + mtype: "" + } + member { + name: "L1" + mtype: "" + } + member { + name: "Level" + mtype: "" + } + member { + name: "OFF" + mtype: "" + } + member { + name: "ON_1" + mtype: "" + } + member { + name: "ON_2" + mtype: "" + } + member { + name: "OPT_LEVEL_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt new file mode 100644 index 00000000000..1bfe723ce75 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.PaddingFIFOQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'padding_fifo_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt new file mode 100644 index 00000000000..dbe25f3a5b9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.PriorityQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'types\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'priority_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: 
"args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt new file mode 100644 index 00000000000..9263d73a511 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt @@ -0,0 +1,61 @@ +path: "tensorflow.QueueBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtypes\', \'shapes\', \'names\', \'queue_ref\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt new file mode 100644 index 00000000000..ec783ffe5a0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt @@ -0,0 +1,62 @@ +path: "tensorflow.RandomShuffleQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'min_after_dequeue\', \'dtypes\', \'shapes\', \'names\', \'seed\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'random_shuffle_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/tensorflow.-reader-base.pbtxt new file mode 100644 index 00000000000..f6a3ce76a15 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-reader-base.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.ReaderBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'reader_ref\', \'supports_serialize\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-register-gradient.pbtxt b/tensorflow/tools/api/golden/tensorflow.-register-gradient.pbtxt new file mode 100644 index 00000000000..4d6e4137d12 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-register-gradient.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.RegisterGradient" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'op_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt new file mode 100644 index 00000000000..808fa0fa217 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.RunMetadata" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COST_GRAPH_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "PARTITION_GRAPHS_FIELD_NUMBER" + mtype: "" + } + member { + name: "STEP_STATS_FIELD_NUMBER" + mtype: "" + } + member_method { + name: 
"ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt new file mode 100644 index 00000000000..5ad6804a78c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt @@ -0,0 +1,116 @@ +path: "tensorflow.RunOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DEBUG_OPTIONS_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FULL_TRACE" + mtype: "" + } + member { + name: "HARDWARE_TRACE" + mtype: "" + } + member { + name: "INTER_OP_THREAD_POOL_FIELD_NUMBER" + mtype: "" + } + member { + name: "NO_TRACE" + mtype: "" + } + member { + name: "OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER" + mtype: "" + } + member { + name: "SOFTWARE_TRACE" + mtype: "" + } + member { + name: "TIMEOUT_IN_MS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TRACE_LEVEL_FIELD_NUMBER" + mtype: "" + } + member { + name: "TraceLevel" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt new file mode 100644 index 00000000000..ec66d7f3354 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt @@ -0,0 +1,108 @@ +path: "tensorflow.SessionLog" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CHECKPOINT" + mtype: "" + } + member { + name: "CHECKPOINT_PATH_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "MSG_FIELD_NUMBER" + mtype: "" 
+ } + member { + name: "START" + mtype: "" + } + member { + name: "STATUS_FIELD_NUMBER" + mtype: "" + } + member { + name: "STATUS_UNSPECIFIED" + mtype: "" + } + member { + name: "STOP" + mtype: "" + } + member { + name: "SessionStatus" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt new file mode 100644 index 00000000000..173cd1963e5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.Session" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "graph_def" + mtype: "" + } + member { + name: "sess_str" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target\', \'graph\', \'config\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\'], " + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "list_devices" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_callable" + argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run" + argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run_setup" + argspec: "args=[\'self\', \'fetches\', \'feeds\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'target\', \'containers\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/tensorflow.-sparse-conditional-accumulator.pbtxt new file mode 100644 index 00000000000..2260279ad2b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.SparseConditionalAccumulator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: 
"accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + } + member_method { + name: "apply_grad" + argspec: "args=[\'self\', \'grad_indices\', \'grad_values\', \'grad_shape\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "apply_indexed_slices_grad" + argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_indexed_slices_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.-sparse-feature.pbtxt new file mode 100644 index 00000000000..d875394fb5d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-sparse-feature.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.SparseFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "already_sorted" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "index_key" + mtype: "" + } + member { + name: "size" + mtype: "" + } + member { + name: "value_key" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-tensor-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-sparse-tensor-value.pbtxt new file mode 100644 index 00000000000..d33fd4d5d7b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-sparse-tensor-value.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.SparseTensorValue" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/tensorflow.-sparse-tensor.pbtxt new file mode 100644 index 00000000000..eac236d4982 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-sparse-tensor.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.SparseTensor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None" + } + 
member_method { + name: "eval" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_value" + argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt new file mode 100644 index 00000000000..781010d75e2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt @@ -0,0 +1,96 @@ +path: "tensorflow.Summary.Audio" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CONTENT_TYPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ENCODED_AUDIO_STRING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "LENGTH_FRAMES_FIELD_NUMBER" + mtype: "" + } + member { + name: "NUM_CHANNELS_FIELD_NUMBER" + mtype: "" + } + member { + name: "SAMPLE_RATE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt new file mode 100644 index 00000000000..feb9c7ee927 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.Summary.Image" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COLORSPACE_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ENCODED_IMAGE_STRING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "HEIGHT_FIELD_NUMBER" + mtype: "" + } + member { + name: "WIDTH_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + 
member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt new file mode 100644 index 00000000000..ffb4f45fc5e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.Summary.Value" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AUDIO_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "HISTO_FIELD_NUMBER" + mtype: "" + } + member { + name: "IMAGE_FIELD_NUMBER" + mtype: "" + } + member { + name: "METADATA_FIELD_NUMBER" + mtype: "" + } + member { + name: "NODE_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER" + mtype: "" + } + member { + name: "SIMPLE_VALUE_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAG_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt new file mode 100644 index 00000000000..38de17fa9e5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.Summary" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "Audio" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "Image" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member { + name: "Value" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } 
+ member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-t-f-record-reader.pbtxt new file mode 100644 index 00000000000..cdf79373919 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-t-f-record-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.TFRecordReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-array.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-array.pbtxt new file mode 100644 index 00000000000..ed088c41ed3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-array.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.TensorArray" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member { + name: "flow" + mtype: "" + } + member { + name: "handle" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'size\', \'dynamic_size\', \'clear_after_read\', \'tensor_array_name\', \'handle\', \'flow\', \'infer_shape\', \'element_shape\', \'colocate_with_first_write_call\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "concat" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'self\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "grad" + argspec: "args=[\'self\', \'source\', \'flow\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "identity" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" 
+ argspec: "args=[\'self\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scatter" + argspec: "args=[\'self\', \'indices\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'self\', \'value\', \'lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "stack" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unstack" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write" + argspec: "args=[\'self\', \'index\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt new file mode 100644 index 00000000000..425c35e0674 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.TensorInfo.CooSparse" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DENSE_SHAPE_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "INDICES_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUES_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt new file mode 100644 index 00000000000..41ea393be51 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt @@ -0,0 +1,96 @@ +path: "tensorflow.TensorInfo" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COO_SPARSE_FIELD_NUMBER" + mtype: "" + } + member { + name: "CooSparse" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "DTYPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_SHAPE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + 
member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt new file mode 100644 index 00000000000..d5b9cb8f5ed --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt @@ -0,0 +1,73 @@ +path: "tensorflow.TensorShape" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dims" + mtype: "" + } + member { + name: "ndims" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dims\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_list" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_proto" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_has_rank" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_fully_defined" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_same_rank" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_fully_defined" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_elements" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank_at_least" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank_at_most" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor.pbtxt new file mode 100644 index 00000000000..38d19bb5374 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-tensor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.Tensor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "OVERLOADABLE_OPERATORS" + 
mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "value_index" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'op\', \'value_index\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "consumers" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_shape" + argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-text-line-reader.pbtxt new file mode 100644 index 00000000000..e9779f07620 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-text-line-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.TextLineReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'skip_header_lines\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-var-len-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.-var-len-feature.pbtxt new file mode 100644 index 00000000000..54b66f43f8e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-var-len-feature.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.VarLenFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt new file mode 100644 index 00000000000..c9b2dfd6772 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt @@ -0,0 +1,97 @@ +path: "tensorflow.VariableScope" +tf_class { + 
is_instance: "" + is_instance: "" + member { + name: "caching_device" + mtype: "" + } + member { + name: "custom_getter" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "initializer" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "original_name_scope" + mtype: "" + } + member { + name: "partitioner" + mtype: "" + } + member { + name: "regularizer" + mtype: "" + } + member { + name: "reuse" + mtype: "" + } + member { + name: "use_resource" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'reuse\', \'name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'name_scope\', \'dtype\', \'use_resource\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'\', \"\", \'None\'], " + } + member_method { + name: "get_collection" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable" + argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "global_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reuse_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_caching_device" + argspec: "args=[\'self\', \'caching_device\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_custom_getter" + argspec: "args=[\'self\', \'custom_getter\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_dtype" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_initializer" + argspec: "args=[\'self\', \'initializer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_partitioner" + argspec: "args=[\'self\', \'partitioner\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_regularizer" + argspec: "args=[\'self\', \'regularizer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_use_resource" + argspec: "args=[\'self\', \'use_resource\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "trainable_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.-save-slice-info.pbtxt new file mode 100644 index 00000000000..ac3ccd468b2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-variable.-save-slice-info.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.Variable.SaveSliceInfo" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "spec" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "to_proto" + 
argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt new file mode 100644 index 00000000000..d67a2713f7a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -0,0 +1,101 @@ +path: "tensorflow.Variable" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "SaveSliceInfo" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "initial_value" + mtype: "" + } + member { + name: "initializer" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assign" + argspec: "args=[\'self\', \'value\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "assign_add" + argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "assign_sub" + argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "count_up_to" + argspec: "args=[\'self\', \'limit\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialized_value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load" + argspec: "args=[\'self\', \'value\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scatter_sub" + argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "set_shape" + argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-whole-file-reader.pbtxt new file mode 100644 index 00000000000..4ac759891c6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-whole-file-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.WholeFileReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: 
"reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.app.pbtxt b/tensorflow/tools/api/golden/tensorflow.app.pbtxt new file mode 100644 index 00000000000..85044a89879 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.app.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.app" +tf_module { + member { + name: "flags" + mtype: "" + } + member_method { + name: "run" + argspec: "args=[\'main\', \'argv\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt new file mode 100644 index 00000000000..ccc60314001 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.compat" +tf_module { + member { + name: "bytes_or_text_types" + mtype: "" + } + member { + name: "complex_types" + mtype: "" + } + member { + name: "integral_types" + mtype: "" + } + member { + name: "real_types" + mtype: "" + } + member_method { + name: "as_bytes" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } + member_method { + name: "as_str" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } + member_method { + name: "as_str_any" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_text" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt new file mode 100644 index 00000000000..00ec669b168 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.constant_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], 
varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-aborted-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-aborted-error.pbtxt new file mode 100644 index 00000000000..ea9186b0b9d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-aborted-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.AbortedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-already-exists-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-already-exists-error.pbtxt new file mode 100644 index 00000000000..4e155081dd2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-already-exists-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.AlreadyExistsError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-cancelled-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-cancelled-error.pbtxt new file mode 100644 index 00000000000..b02a0e023aa --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-cancelled-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.CancelledError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-data-loss-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-data-loss-error.pbtxt new file mode 100644 index 00000000000..c1fa66342a7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-data-loss-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.DataLossError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-deadline-exceeded-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-deadline-exceeded-error.pbtxt new file mode 100644 index 00000000000..8e037936191 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-deadline-exceeded-error.pbtxt @@ -0,0 
+1,30 @@ +path: "tensorflow.errors.DeadlineExceededError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-failed-precondition-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-failed-precondition-error.pbtxt new file mode 100644 index 00000000000..384d4b534c6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-failed-precondition-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.FailedPreconditionError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-internal-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-internal-error.pbtxt new file mode 100644 index 00000000000..ac5c4d7879b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-internal-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.InternalError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-invalid-argument-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-invalid-argument-error.pbtxt new file mode 100644 index 00000000000..161edd4a7c5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-invalid-argument-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.InvalidArgumentError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-not-found-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-not-found-error.pbtxt new file mode 100644 index 00000000000..1e64730ac6d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-not-found-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.NotFoundError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: 
"op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-op-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-op-error.pbtxt new file mode 100644 index 00000000000..b1f14c0457d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-op-error.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.errors.OpError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-out-of-range-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-out-of-range-error.pbtxt new file mode 100644 index 00000000000..6365e472868 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-out-of-range-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.OutOfRangeError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-permission-denied-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-permission-denied-error.pbtxt new file mode 100644 index 00000000000..dc8a66f9ead --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-permission-denied-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.PermissionDeniedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-resource-exhausted-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-resource-exhausted-error.pbtxt new file mode 100644 index 00000000000..85bb384b469 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-resource-exhausted-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.ResourceExhaustedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unauthenticated-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-unauthenticated-error.pbtxt 
new file mode 100644 index 00000000000..d57d7ac2f20 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-unauthenticated-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnauthenticatedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unavailable-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-unavailable-error.pbtxt new file mode 100644 index 00000000000..cc33e6ed8d1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-unavailable-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnavailableError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unimplemented-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-unimplemented-error.pbtxt new file mode 100644 index 00000000000..b8c2e22dbd7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-unimplemented-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnimplementedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unknown-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.-unknown-error.pbtxt new file mode 100644 index 00000000000..8ffcfae95b8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.-unknown-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnknownError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=[\'2\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.pbtxt b/tensorflow/tools/api/golden/tensorflow.errors.pbtxt new file mode 100644 index 00000000000..0ad1c19603b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.errors.pbtxt @@ -0,0 +1,151 @@ +path: "tensorflow.errors" +tf_module { + member { + name: "ABORTED" + mtype: "" + } + member { + name: "ALREADY_EXISTS" + mtype: "" + } + member { + name: "AbortedError" + mtype: "" + } + member { + name: "AlreadyExistsError" + 
mtype: "" + } + member { + name: "CANCELLED" + mtype: "" + } + member { + name: "CancelledError" + mtype: "" + } + member { + name: "DATA_LOSS" + mtype: "" + } + member { + name: "DEADLINE_EXCEEDED" + mtype: "" + } + member { + name: "DataLossError" + mtype: "" + } + member { + name: "DeadlineExceededError" + mtype: "" + } + member { + name: "FAILED_PRECONDITION" + mtype: "" + } + member { + name: "FailedPreconditionError" + mtype: "" + } + member { + name: "INTERNAL" + mtype: "" + } + member { + name: "INVALID_ARGUMENT" + mtype: "" + } + member { + name: "InternalError" + mtype: "" + } + member { + name: "InvalidArgumentError" + mtype: "" + } + member { + name: "NOT_FOUND" + mtype: "" + } + member { + name: "NotFoundError" + mtype: "" + } + member { + name: "OK" + mtype: "" + } + member { + name: "OUT_OF_RANGE" + mtype: "" + } + member { + name: "OpError" + mtype: "" + } + member { + name: "OutOfRangeError" + mtype: "" + } + member { + name: "PERMISSION_DENIED" + mtype: "" + } + member { + name: "PermissionDeniedError" + mtype: "" + } + member { + name: "RESOURCE_EXHAUSTED" + mtype: "" + } + member { + name: "ResourceExhaustedError" + mtype: "" + } + member { + name: "UNAUTHENTICATED" + mtype: "" + } + member { + name: "UNAVAILABLE" + mtype: "" + } + member { + name: "UNIMPLEMENTED" + mtype: "" + } + member { + name: "UNKNOWN" + mtype: "" + } + member { + name: "UnauthenticatedError" + mtype: "" + } + member { + name: "UnavailableError" + mtype: "" + } + member { + name: "UnimplementedError" + mtype: "" + } + member { + name: "UnknownError" + mtype: "" + } + member_method { + name: "error_code_from_exception_type" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exception_type_from_error_code" + argspec: "args=[\'error_code\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "raise_exception_on_not_ok_status" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt new file mode 100644 index 00000000000..5dbfe217264 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.estimator.EstimatorSpec" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "eval_metric_ops" + mtype: "" + } + member { + name: "export_outputs" + mtype: "" + } + member { + name: "loss" + mtype: "" + } + member { + name: "predictions" + mtype: "" + } + member { + name: "scaffold" + mtype: "" + } + member { + name: "train_op" + mtype: "" + } + member { + name: "training_chief_hooks" + mtype: "" + } + member { + name: "training_hooks" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt new file mode 100644 index 00000000000..7a769fd546c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.estimator.Estimator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_fn\', 
\'model_dir\', \'config\', \'params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-mode-keys.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-mode-keys.pbtxt new file mode 100644 index 00000000000..6a1c24fa63f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-mode-keys.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.estimator.ModeKeys" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "EVAL" + mtype: "" + } + member { + name: "PREDICT" + mtype: "" + } + member { + name: "TRAIN" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt new file mode 100644 index 00000000000..d69c475a313 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -0,0 +1,77 @@ +path: "tensorflow.estimator.RunConfig" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cluster_spec" + mtype: "" + } + member { + name: "evaluation_master" + mtype: "" + } + member { + name: "is_chief" + mtype: "" + } + member { + name: "keep_checkpoint_every_n_hours" + mtype: "" + } + member { + name: "keep_checkpoint_max" + mtype: "" + } + member { + name: "master" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "num_ps_replicas" + mtype: "" + } + member { + name: "num_worker_replicas" + mtype: "" + } + member { + name: "save_checkpoints_secs" + mtype: "" + } + member { + name: "save_checkpoints_steps" + mtype: "" + } + member { + name: "save_summary_steps" + mtype: "" + } + member { + name: "session_config" + mtype: "" + } + member { + name: "task_id" + mtype: "" + } + member { + name: "task_type" + mtype: "" + } + member { + name: "tf_random_seed" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "replace" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt new file mode 100644 index 00000000000..3cf7af8da95 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.ClassificationOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: 
"__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.pbtxt new file mode 100644 index 00000000000..2df1840c4a4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.estimator.export.ClassificationOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "classes" + mtype: "" + } + member { + name: "scores" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scores\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt new file mode 100644 index 00000000000..5d165ccbf91 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.ExportOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.pbtxt new file mode 100644 index 00000000000..fa62e8ced80 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.estimator.export.ExportOutput" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt new file mode 100644 index 00000000000..743495ba98c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.PredictOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.pbtxt new file mode 100644 index 00000000000..e0160b10ce1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.export.PredictOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "outputs" + mtype: "" + } + member_method 
{ + name: "__init__" + argspec: "args=[\'self\', \'outputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt new file mode 100644 index 00000000000..dbf4e3dec85 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.RegressionOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.pbtxt new file mode 100644 index 00000000000..905f0e05535 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.export.RegressionOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "value" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.-serving-input-receiver.pbtxt new file mode 100644 index 00000000000..0d9e0443088 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.-serving-input-receiver.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.estimator.export.ServingInputReceiver" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "features" + mtype: "" + } + member { + name: "receiver_tensors" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt new file mode 100644 index 00000000000..4d0dddb3bc0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.estimator.export" +tf_module { + member { + name: "ClassificationOutput" + mtype: "" + } + member { + name: "ExportOutput" + mtype: "" + } + member { + name: "PredictOutput" + mtype: "" + } + member { + name: "RegressionOutput" + mtype: "" + } + member { + name: "ServingInputReceiver" + mtype: "" + } + member_method { + name: "build_parsing_serving_input_receiver_fn" + argspec: "args=[\'feature_spec\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "build_raw_serving_input_receiver_fn" + argspec: "args=[\'features\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.inputs.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.inputs.pbtxt 
new file mode 100644 index 00000000000..b318fea1f82 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.inputs.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.estimator.inputs" +tf_module { + member_method { + name: "numpy_input_fn" + argspec: "args=[\'x\', \'y\', \'batch_size\', \'num_epochs\', \'shuffle\', \'queue_capacity\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'128\', \'1\', \'None\', \'1000\', \'1\'], " + } + member_method { + name: "pandas_input_fn" + argspec: "args=[\'x\', \'y\', \'batch_size\', \'num_epochs\', \'shuffle\', \'queue_capacity\', \'num_threads\', \'target_column\'], varargs=None, keywords=None, defaults=[\'None\', \'128\', \'1\', \'None\', \'1000\', \'1\', \'target\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt new file mode 100644 index 00000000000..0d5dc73271d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.estimator" +tf_module { + member { + name: "Estimator" + mtype: "" + } + member { + name: "EstimatorSpec" + mtype: "" + } + member { + name: "ModeKeys" + mtype: "" + } + member { + name: "RunConfig" + mtype: "" + } + member { + name: "export" + mtype: "" + } + member { + name: "inputs" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt new file mode 100644 index 00000000000..4c633a850f8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.feature_column" +tf_module { + member_method { + name: "bucketized_column" + argspec: "args=[\'source_column\', \'boundaries\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_column_with_hash_bucket" + argspec: "args=[\'key\', \'hash_bucket_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "categorical_column_with_identity" + argspec: "args=[\'key\', \'num_buckets\', \'default_value\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "categorical_column_with_vocabulary_file" + argspec: "args=[\'key\', \'vocabulary_file\', \'vocabulary_size\', \'num_oov_buckets\', \'default_value\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"\"], " + } + member_method { + name: "categorical_column_with_vocabulary_list" + argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } + member_method { + name: "crossed_column" + argspec: "args=[\'keys\', \'hash_bucket_size\', \'hash_key\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "embedding_column" + argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "indicator_column" + argspec: "args=[\'categorical_column\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "input_layer" + argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "linear_model" + argspec: "args=[\'features\', 
\'feature_columns\', \'units\', \'sparse_combiner\', \'weight_collections\', \'trainable\'], varargs=None, keywords=None, defaults=[\'1\', \'sum\', \'None\', \'True\'], " + } + member_method { + name: "make_parse_example_spec" + argspec: "args=[\'feature_columns\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "numeric_column" + argspec: "args=[\'key\', \'shape\', \'default_value\', \'dtype\', \'normalizer_fn\'], varargs=None, keywords=None, defaults=[\'(1,)\', \'None\', \"\", \'None\'], " + } + member_method { + name: "weighted_categorical_column" + argspec: "args=[\'categorical_column\', \'weight_feature_key\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt new file mode 100644 index 00000000000..eecfaffd0a6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.FastGFile" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt new file mode 100644 index 00000000000..305251059d9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.GFile" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], 
varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt new file mode 100644 index 00000000000..6e8894180a4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.Open" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.pbtxt new file mode 100644 index 00000000000..65b55a8b7c4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.gfile.pbtxt @@ -0,0 +1,63 @@ +path: "tensorflow.gfile" +tf_module { + member { + name: "FastGFile" + mtype: "" + } + member { + name: "GFile" + mtype: "" + } + member { + name: "Open" + mtype: "" + } + member_method { + name: "Copy" + argspec: "args=[\'oldpath\', \'newpath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "DeleteRecursively" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Exists" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Glob" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + 
member_method { + name: "IsDirectory" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ListDirectory" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MakeDirs" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MkDir" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Remove" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Rename" + argspec: "args=[\'oldname\', \'newname\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "Stat" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Walk" + argspec: "args=[\'top\', \'in_order\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.graph_util.pbtxt b/tensorflow/tools/api/golden/tensorflow.graph_util.pbtxt new file mode 100644 index 00000000000..eeabf845dca --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.graph_util.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.graph_util" +tf_module { + member_method { + name: "convert_variables_to_constants" + argspec: "args=[\'sess\', \'input_graph_def\', \'output_node_names\', \'variable_names_whitelist\', \'variable_names_blacklist\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "extract_sub_graph" + argspec: "args=[\'graph_def\', \'dest_nodes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "must_run_on_cpu" + argspec: "args=[\'node\', \'pin_variables_on_cpu\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "remove_training_nodes" + argspec: "args=[\'input_graph\', \'protected_nodes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_shape_from_node_def_name" + argspec: "args=[\'graph\', \'input_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.image.-resize-method.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.-resize-method.pbtxt new file mode 100644 index 00000000000..dbc360b13ee --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.image.-resize-method.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.image.ResizeMethod" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AREA" + mtype: "" + } + member { + name: "BICUBIC" + mtype: "" + } + member { + name: "BILINEAR" + mtype: "" + } + member { + name: "NEAREST_NEIGHBOR" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt new file mode 100644 index 00000000000..93257c84a1f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.image" +tf_module { + member { + name: "ResizeMethod" + mtype: "" + } + member_method { + name: "adjust_brightness" + argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_contrast" + argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_gamma" + argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, 
defaults=[\'1\', \'1\'], " + } + member_method { + name: "adjust_hue" + argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "adjust_saturation" + argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "central_crop" + argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convert_image_dtype" + argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "crop_and_resize" + argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "crop_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "decode_bmp" + argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "decode_gif" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_image" + argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "decode_jpeg" + argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "decode_png" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "draw_bounding_boxes" + argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "encode_jpeg" + argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "encode_png" + argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "extract_glimpse" + argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "flip_left_right" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flip_up_down" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "grayscale_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "hsv_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: 
"non_max_suppression" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "non_max_suppression_v2" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pad_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "per_image_standardization" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "random_brightness" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_contrast" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_left_right" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_up_down" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_hue" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_saturation" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "resize_area" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "resize_bicubic" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "resize_bilinear" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "resize_image_with_crop_or_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_images" + argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\'], varargs=None, keywords=None, defaults=[\'0\', \'False\'], " + } + member_method { + name: "resize_nearest_neighbor" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rgb_to_grayscale" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rgb_to_hsv" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rot90" + argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "sample_distorted_bounding_box" + argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: 
"total_variation" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "transpose_image" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt new file mode 100644 index 00000000000..418ca3ea466 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt @@ -0,0 +1,63 @@ +path: "tensorflow.layers" +tf_module { + member_method { + name: "average_pooling1d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "average_pooling2d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "average_pooling3d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "batch_normalization" + argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'\', \'\', \'\', \'\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'False\'], " + } + member_method { + name: "conv1d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv2d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv2d_transpose" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv3d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', 
\'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'(1, 1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv3d_transpose" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "dense" + argspec: "args=[\'inputs\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "dropout" + argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "max_pooling1d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "max_pooling2d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "max_pooling3d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "separable_conv2d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.logging.pbtxt b/tensorflow/tools/api/golden/tensorflow.logging.pbtxt new file mode 100644 index 00000000000..85bb15455da --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.logging.pbtxt @@ -0,0 +1,83 @@ +path: "tensorflow.logging" +tf_module { + member { + name: "DEBUG" + mtype: "" + } + member { + name: "ERROR" + mtype: "" + } + member { + name: "FATAL" + mtype: "" + } + member { + name: "INFO" + mtype: "" + } + member { + name: "WARN" + mtype: "" + } + member_method { + name: "TaskLevelStatusMessage" + argspec: "args=[\'msg\'], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "debug" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "error" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "fatal" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_verbosity" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "info" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "log" + argspec: "args=[\'level\', \'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "log_every_n" + argspec: "args=[\'level\', \'msg\', \'n\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "log_first_n" + argspec: "args=[\'level\', \'msg\', \'n\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "log_if" + argspec: "args=[\'level\', \'msg\', \'condition\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "set_verbosity" + argspec: "args=[\'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "vlog" + argspec: "args=[\'level\', \'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "warn" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "warning" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt new file mode 100644 index 00000000000..4bdc73370bf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt @@ -0,0 +1,32 @@ +path: "tensorflow.losses.Reduction" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "MEAN" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } + member { + name: "SUM_BY_NONZERO_WEIGHTS" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "all" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "validate" + argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/tensorflow.losses.pbtxt new file mode 100644 index 00000000000..79443839b9a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.losses.pbtxt @@ -0,0 +1,71 @@ +path: "tensorflow.losses" +tf_module { + member { + name: "Reduction" + mtype: "" + } + member_method { + name: "absolute_difference" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'loss\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'losses\'], " + } + member_method { + name: "compute_weighted_loss" + argspec: "args=[\'losses\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: 
"cosine_distance" + argspec: "args=[\'labels\', \'predictions\', \'dim\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "get_losses" + argspec: "args=[\'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'None\', \'losses\'], " + } + member_method { + name: "get_regularization_loss" + argspec: "args=[\'scope\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'total_regularization_loss\'], " + } + member_method { + name: "get_regularization_losses" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_total_loss" + argspec: "args=[\'add_regularization_losses\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'total_loss\'], " + } + member_method { + name: "hinge_loss" + argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "huber_loss" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'delta\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "log_loss" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'epsilon\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1e-07\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "mean_pairwise_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\'], " + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "sigmoid_cross_entropy" + argspec: "args=[\'multi_class_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "softmax_cross_entropy" + argspec: "args=[\'onehot_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "sparse_softmax_cross_entropy" + argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt new file mode 100644 index 00000000000..262d11c38e1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt @@ -0,0 +1,99 @@ +path: "tensorflow.metrics" +tf_module { + member_method { + name: "accuracy" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', 
\'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "auc" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\'], " + } + member_method { + name: "false_negatives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "false_positives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean" + argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_absolute_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_cosine_distance" + argspec: "args=[\'labels\', \'predictions\', \'dim\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_iou" + argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_per_class_accuracy" + argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_relative_error" + argspec: "args=[\'labels\', \'predictions\', \'normalizer\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_tensor" + argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "percentage_below" + argspec: "args=[\'values\', \'threshold\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', 
\'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "root_mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sensitivity_at_specificity" + argspec: "args=[\'labels\', \'predictions\', \'specificity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_average_precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "specificity_at_sensitivity" + argspec: "args=[\'labels\', \'predictions\', \'sensitivity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "true_positives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt new file mode 100644 index 00000000000..9f817beafd9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -0,0 +1,339 @@ +path: "tensorflow.nn" +tf_module { + member { + name: "rnn_cell" + mtype: "" + } + member_method { + name: "all_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "atrous_conv2d" + argspec: "args=[\'value\', \'filters\', \'rate\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atrous_conv2d_transpose" + argspec: "args=[\'value\', \'filters\', \'output_shape\', \'rate\', \'padding\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\'], " + } + member_method { + name: "avg_pool" + argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + } + member_method { + name: "avg_pool3d" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "batch_norm_with_global_normalization" + argspec: "args=[\'t\', \'m\', \'v\', \'beta\', \'gamma\', \'variance_epsilon\', \'scale_after_normalization\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_normalization" + argspec: "args=[\'x\', \'mean\', \'variance\', \'offset\', \'scale\', \'variance_epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bias_add" + argspec: "args=[\'value\', \'bias\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "bidirectional_dynamic_rnn" + argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'sequence_length\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "compute_accidental_hits" + argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "conv1d" + argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv2d_backprop_filter" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv2d_backprop_input" + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv2d_transpose" + argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NHWC\', \'None\'], " + } + member_method { + name: "conv3d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "conv3d_backprop_filter_v2" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "conv3d_transpose" + argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NDHWC\', \'None\'], " + } + member_method 
{ + name: "convolution" + argspec: "args=[\'input\', \'filter\', \'padding\', \'strides\', \'dilation_rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "crelu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ctc_beam_search_decoder" + argspec: "args=[\'inputs\', \'sequence_length\', \'beam_width\', \'top_paths\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'100\', \'1\', \'True\'], " + } + member_method { + name: "ctc_greedy_decoder" + argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "ctc_loss" + argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], " + } + member_method { + name: "depthwise_conv2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native_backprop_filter" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native_backprop_input" + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "dilation2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dropout" + argspec: "args=[\'x\', \'keep_prob\', \'noise_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "dynamic_rnn" + argspec: "args=[\'cell\', \'inputs\', \'sequence_length\', \'initial_state\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "elu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "embedding_lookup" + argspec: "args=[\'params\', \'ids\', \'partition_strategy\', \'name\', \'validate_indices\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'mod\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "embedding_lookup_sparse" + argspec: "args=[\'params\', \'sp_ids\', \'sp_weights\', \'partition_strategy\', \'name\', \'combiner\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'mod\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "erosion2d" + argspec: "args=[\'value\', \'kernel\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: 
"fixed_unigram_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'vocab_file\', \'distortion\', \'num_reserved_ids\', \'num_shards\', \'shard\', \'unigrams\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'1.0\', \'0\', \'1\', \'0\', \'()\', \'None\', \'None\'], " + } + member_method { + name: "fractional_avg_pool" + argspec: "args=[\'value\', \'pooling_ratio\', \'pseudo_random\', \'overlapping\', \'deterministic\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "fractional_max_pool" + argspec: "args=[\'value\', \'pooling_ratio\', \'pseudo_random\', \'overlapping\', \'deterministic\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "fused_batch_norm" + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], " + } + member_method { + name: "in_top_k" + argspec: "args=[\'predictions\', \'targets\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "l2_loss" + argspec: "args=[\'t\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "l2_normalize" + argspec: "args=[\'x\', \'dim\', \'epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'1e-12\', \'None\'], " + } + member_method { + name: "learned_unigram_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "local_response_normalization" + argspec: "args=[\'input\', \'depth_radius\', \'bias\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "log_poisson_loss" + argspec: "args=[\'targets\', \'log_input\', \'compute_full_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "log_softmax" + argspec: "args=[\'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + } + member_method { + name: "log_uniform_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "lrn" + argspec: "args=[\'input\', \'depth_radius\', \'bias\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "max_pool" + argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + } + member_method { + name: "max_pool3d" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "max_pool_with_argmax" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'Targmax\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } 
+ member_method { + name: "moments" + argspec: "args=[\'x\', \'axes\', \'shift\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "nce_loss" + argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'mod\', \'nce_loss\'], " + } + member_method { + name: "normalize_moments" + argspec: "args=[\'counts\', \'mean_ss\', \'variance_ss\', \'shift\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pool" + argspec: "args=[\'input\', \'window_shape\', \'pooling_type\', \'padding\', \'dilation_rate\', \'strides\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "quantized_avg_pool" + argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "quantized_conv2d" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "quantized_max_pool" + argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "quantized_relu_x" + argspec: "args=[\'features\', \'max_value\', \'min_features\', \'max_features\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "raw_rnn" + argspec: "args=[\'cell\', \'loop_fn\', \'parallel_iterations\', \'swap_memory\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "relu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "relu6" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "relu_layer" + argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sampled_softmax_loss" + argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], " + } + member_method { + name: "separable_conv2d" + argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sigmoid_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "softmax" + argspec: "args=[\'logits\', \'dim\', \'name\'], varargs=None, 
keywords=None, defaults=[\'-1\', \'None\'], " + } + member_method { + name: "softmax_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_softmax_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_bidirectional_rnn" + argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_rnn" + argspec: "args=[\'cell\', \'inputs\', \'initial_state\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_state_saving_rnn" + argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sufficient_statistics" + argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "tanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "top_k" + argspec: "args=[\'input\', \'k\', \'sorted\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'True\', \'None\'], " + } + member_method { + name: "uniform_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "weighted_cross_entropy_with_logits" + argspec: "args=[\'targets\', \'logits\', \'pos_weight\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "weighted_moments" + argspec: "args=[\'x\', \'axes\', \'frequency_weights\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "with_space_to_batch" + argspec: "args=[\'input\', \'dilation_rate\', \'padding\', \'op\', \'filter_shape\', \'spatial_dims\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "xw_plus_b" + argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zero_fraction" + argspec: "args=[\'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt new file mode 100644 index 00000000000..fbf68c50a1a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -0,0 +1,95 @@ +path: 
"tensorflow.nn.rnn_cell.BasicLSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1.0\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt new file mode 100644 index 00000000000..606d20d8f0f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.BasicRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: 
"args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt new file mode 100644 index 00000000000..ead1d0cfc51 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.DeviceWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: 
"args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt new file mode 100644 index 00000000000..2db4996b2a4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.DropoutWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt new file mode 100644 index 00000000000..101f6df1d84 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.GRUCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + 
member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt new file mode 100644 index 00000000000..c87546d5285 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.LSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, 
defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt new file mode 100644 index 00000000000..1de8a55dcca --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.nn.rnn_cell.LSTMStateTuple" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "c" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "h" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt new file mode 100644 index 00000000000..bc01ccfa647 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.MultiRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt new file mode 100644 index 00000000000..b19ee18b40f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -0,0 +1,94 @@ +path: "tensorflow.nn.rnn_cell.RNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \"\"], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt new file 
mode 100644 index 00000000000..b21d9a8ee33 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -0,0 +1,95 @@ +path: "tensorflow.nn.rnn_cell.ResidualWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt new file mode 100644 index 00000000000..64697e8a02b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt @@ -0,0 +1,43 @@ +path: "tensorflow.nn.rnn_cell" +tf_module { + member { + name: "BasicLSTMCell" + mtype: "" + } + member { + name: "BasicRNNCell" + mtype: "" + } + member { + name: "DeviceWrapper" + mtype: "" + } + member { + name: "DropoutWrapper" + mtype: "" + } + member { + name: "GRUCell" + mtype: "" + } + member { + name: "LSTMCell" + mtype: "" + } + member { + name: "LSTMStateTuple" + mtype: "" + } + member { + name: "MultiRNNCell" + mtype: "" + } + member { + name: "RNNCell" + mtype: "" + } + member { + name: "ResidualWrapper" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt new file mode 100644 index 00000000000..210b56242b2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.ones_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + 
name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt new file mode 100644 index 00000000000..13ec7454f41 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.orthogonal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt new file mode 100644 index 00000000000..342ee95f74d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -0,0 +1,1967 @@ +path: "tensorflow" +tf_module { + member { + name: "AggregationMethod" + mtype: "" + } + member { + name: "AttrValue" + mtype: "" + } + member { + name: "COMPILER_VERSION" + mtype: "" + } + member { + name: "ConditionalAccumulator" + mtype: "" + } + member { + name: "ConditionalAccumulatorBase" + mtype: "" + } + member { + name: "ConfigProto" + mtype: "" + } + member { + name: "DType" + mtype: "" + } + member { + name: "DeviceSpec" + mtype: "" + } + member { + name: "Dimension" + mtype: "" + } + member { + name: "Event" + mtype: "" + } + member { + name: "FIFOQueue" + mtype: "" + } + member { + name: "FixedLenFeature" + mtype: "" + } + member { + name: "FixedLenSequenceFeature" + mtype: "" + } + member { + name: "FixedLengthRecordReader" + mtype: "" + } + member { + name: "GIT_VERSION" + mtype: "" + } + member { + name: "GPUOptions" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION_MIN_CONSUMER" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION_MIN_PRODUCER" + mtype: "" + } + member { + name: "Graph" + mtype: "" + } + member { + name: "GraphDef" + mtype: "" + } + member { + name: "GraphKeys" + mtype: "" + } + member { + name: "GraphOptions" + mtype: "" + } + member { + name: "HistogramProto" + mtype: "" + } + member { + name: "IdentityReader" + mtype: "" + } + member { + name: "IndexedSlices" + mtype: "" + } + member { + name: "InteractiveSession" + mtype: "" + } + member { + name: "LogMessage" + mtype: "" + } + member { + name: "MetaGraphDef" + mtype: "" + } + member { + name: "NameAttrList" + mtype: "" + } + member { + name: "NodeDef" + mtype: "" + } + member { + name: "OpError" + mtype: "" + } + member { + name: "Operation" + mtype: "" + } + member { + name: "OptimizerOptions" + mtype: "" + } + member { + name: "PaddingFIFOQueue" + mtype: "" + } + member { + name: "PriorityQueue" + mtype: "" + } + member { + name: "QUANTIZED_DTYPES" + mtype: "" + } + member { + name: "QueueBase" + mtype: "" + } + member { + name: "RandomShuffleQueue" + mtype: "" + } + member { + name: "ReaderBase" + mtype: "" + } + member { + name: 
"RegisterGradient" + mtype: "" + } + member { + name: "RunMetadata" + mtype: "" + } + member { + name: "RunOptions" + mtype: "" + } + member { + name: "Session" + mtype: "" + } + member { + name: "SessionLog" + mtype: "" + } + member { + name: "SparseConditionalAccumulator" + mtype: "" + } + member { + name: "SparseFeature" + mtype: "" + } + member { + name: "SparseTensor" + mtype: "" + } + member { + name: "SparseTensorValue" + mtype: "" + } + member { + name: "Summary" + mtype: "" + } + member { + name: "TFRecordReader" + mtype: "" + } + member { + name: "Tensor" + mtype: "" + } + member { + name: "TensorArray" + mtype: "" + } + member { + name: "TensorInfo" + mtype: "" + } + member { + name: "TensorShape" + mtype: "" + } + member { + name: "TextLineReader" + mtype: "" + } + member { + name: "VERSION" + mtype: "" + } + member { + name: "VarLenFeature" + mtype: "" + } + member { + name: "Variable" + mtype: "" + } + member { + name: "VariableScope" + mtype: "" + } + member { + name: "WholeFileReader" + mtype: "" + } + member { + name: "app" + mtype: "" + } + member { + name: "bfloat16" + mtype: "" + } + member { + name: "bool" + mtype: "" + } + member { + name: "compat" + mtype: "" + } + member { + name: "complex128" + mtype: "" + } + member { + name: "complex64" + mtype: "" + } + member { + name: "constant_initializer" + mtype: "" + } + member { + name: "contrib" + mtype: "" + } + member { + name: "double" + mtype: "" + } + member { + name: "errors" + mtype: "" + } + member { + name: "estimator" + mtype: "" + } + member { + name: "feature_column" + mtype: "" + } + member { + name: "flags" + mtype: "" + } + member { + name: "float16" + mtype: "" + } + member { + name: "float32" + mtype: "" + } + member { + name: "float64" + mtype: "" + } + member { + name: "gfile" + mtype: "" + } + member { + name: "graph_util" + mtype: "" + } + member { + name: "half" + mtype: "" + } + member { + name: "image" + mtype: "" + } + member { + name: "int16" + mtype: "" + } + member { + name: "int32" + mtype: "" + } + member { + name: "int64" + mtype: "" + } + member { + name: "int8" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "logging" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "newaxis" + mtype: "" + } + member { + name: "nn" + mtype: "" + } + member { + name: "ones_initializer" + mtype: "" + } + member { + name: "orthogonal_initializer" + mtype: "" + } + member { + name: "python_io" + mtype: "" + } + member { + name: "pywrap_tensorflow" + mtype: "" + } + member { + name: "qint16" + mtype: "" + } + member { + name: "qint32" + mtype: "" + } + member { + name: "qint8" + mtype: "" + } + member { + name: "quint16" + mtype: "" + } + member { + name: "quint8" + mtype: "" + } + member { + name: "random_normal_initializer" + mtype: "" + } + member { + name: "random_uniform_initializer" + mtype: "" + } + member { + name: "resource" + mtype: "" + } + member { + name: "resource_loader" + mtype: "" + } + member { + name: "saved_model" + mtype: "" + } + member { + name: "sets" + mtype: "" + } + member { + name: "spectral" + mtype: "" + } + member { + name: "string" + mtype: "" + } + member { + name: "summary" + mtype: "" + } + member { + name: "sysconfig" + mtype: "" + } + member { + name: "test" + mtype: "" + } + member { + name: "train" + mtype: "" + } + member { + name: "truncated_normal_initializer" + mtype: "" + } + member { + name: "uint16" + mtype: "" + } + member { + name: "uint8" + mtype: "" + } + 
member { + name: "uniform_unit_scaling_initializer" + mtype: "" + } + member { + name: "user_ops" + mtype: "" + } + member { + name: "zeros_initializer" + mtype: "" + } + member_method { + name: "Assert" + argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "NoGradient" + argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "NotDifferentiable" + argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Print" + argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "abs" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "accumulate_n" + argspec: "args=[\'inputs\', \'shape\', \'tensor_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_check_numerics_ops" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_n" + argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_to_collection" + argspec: "args=[\'name\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "all_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "arg_max" + argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "arg_min" + argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "argmax" + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "argmin" + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "as_dtype" + argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assert_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_greater" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_greater_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_integer" + argspec: "args=[\'x\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_less" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_less_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_negative" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_non_negative" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_non_positive" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_none_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_positive" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_proper_iterable" + argspec: "args=[\'values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_rank" + argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_rank_at_least" + argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_same_float_dtype" + argspec: "args=[\'tensors\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_scalar" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assert_type" + argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_variables_initialized" + argspec: "args=[\'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assign" + argspec: "args=[\'ref\', \'value\', \'validate_shape\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "assign_add" + argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assign_sub" + argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + 
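The `assert_*` argspecs just recorded all follow the same pattern: each returns a check op that is typically gated with `control_dependencies` (whose argspec appears below). A minimal sketch, assuming TF 1.x graph mode with an invented tensor:

```python
# Sketch of the tf.assert_* pattern recorded above (TF 1.x).
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])

# assert_positive: args=['x', 'data', 'summarize', 'message', 'name']
check = tf.assert_positive(x, message='x must be > 0')

# y is only computed after the check op runs (and passes).
with tf.control_dependencies([check]):
    y = tf.identity(x) * 2.0
```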
member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_to_space" + argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bincount" + argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " + } + member_method { + name: "bitcast" + argspec: "args=[\'input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "boolean_mask" + argspec: "args=[\'tensor\', \'mask\', \'name\'], varargs=None, keywords=None, defaults=[\'boolean_mask\'], " + } + member_method { + name: "broadcast_dynamic_shape" + argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast_static_shape" + argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "case" + argspec: "args=[\'pred_fn_pairs\', \'default\', \'exclusive\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'case\'], " + } + member_method { + name: "cast" + argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky_solve" + argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clip_by_average_norm" + argspec: "args=[\'t\', \'clip_norm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clip_by_global_norm" + argspec: "args=[\'t_list\', \'clip_norm\', \'use_norm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "clip_by_norm" + argspec: "args=[\'t\', \'clip_norm\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "clip_by_value" + argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "complex" + argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "concat" + argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'concat\'], " + } + member_method { + name: "cond" + argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "confusion_matrix" + argspec: 
"args=[\'labels\', \'predictions\', \'num_classes\', \'dtype\', \'name\', \'weights\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "conj" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "constant" + argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Const\', \'False\'], " + } + member_method { + name: "container" + argspec: "args=[\'container_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "control_dependencies" + argspec: "args=[\'control_inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convert_to_tensor" + argspec: "args=[\'value\', \'dtype\', \'name\', \'preferred_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "convert_to_tensor_or_indexed_slices" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "convert_to_tensor_or_sparse_tensor" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_nonzero" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'dtype\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \"\", \'None\', \'None\'], " + } + member_method { + name: "count_up_to" + argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "create_partitioned_variables" + argspec: "args=[\'shape\', \'slicing\', \'initializer\', \'dtype\', \'trainable\', \'collections\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\"\", \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cumprod" + argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "cumsum" + argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_csv" + argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "delete_session_tensor" + argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\'], " + } + member_method { + name: "depth_to_space" + argspec: "args=[\'input\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "deserialize_many_sparse" + argspec: "args=[\'serialized_sparse\', \'dtype\', \'rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "device" + argspec: "args=[\'device_name_or_function\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "div" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "divide" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dynamic_partition" + argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dynamic_stitch" + argspec: "args=[\'indices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "edit_distance" + argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], " + } + member_method { + name: "einsum" + argspec: "args=[\'equation\'], varargs=inputs, keywords=None, defaults=None" + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expand_dims" + argspec: "args=[\'input\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "eye" + argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', 
\'None\', \'None\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "fft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fill" + argspec: "args=[\'dims\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fixed_size_partitioner" + argspec: "args=[\'num_shards\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor_div" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floordiv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floormod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "foldl" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "foldr" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection" + argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection_ref" + argspec: "args=[\'key\'], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "get_default_graph" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_default_session" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_local_variable" + argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_seed" + argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_session_handle" + argspec: "args=[\'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_session_tensor" + argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_variable" + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "get_variable_scope" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "global_norm" + argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "global_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "global_variables_initializer" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gradients" + argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "group" + argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None" + } + member_method { + name: "hessians" + argspec: "args=[\'ys\', \'xs\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'hessians\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "histogram_fixed_width" + argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"\", \'None\'], " + } + member_method { + name: "identity" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, 
keywords=None, defaults=[\'None\'], " + } + member_method { + name: "imag" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "import_graph_def" + argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "initialize_all_tables" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " + } + member_method { + name: "initialize_all_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialize_local_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialize_variables" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_non_decreasing" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_numeric_tensor" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_strictly_increasing" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_variable_initialized" + argspec: "args=[\'variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "lbeta" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'lbeta\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lin_space" + argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "linspace" + argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "load_file_system_library" + argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_op_library" + argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "local_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "local_variables_initializer" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log_sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], " + } + member_method { + name: "make_ndarray" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_template" + argspec: "args=[\'name_\', \'func_\', \'create_scope_now_\', \'unique_name_\', \'custom_getter_\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\'], " + } + member_method { + name: "make_tensor_proto" + argspec: "args=[\'values\', \'dtype\', \'shape\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "map_fn" + argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "matrix_band_part" + argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_determinant" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_inverse" + argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "matrix_set_diag" + argspec: "args=[\'input\', \'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_solve" + argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "matrix_solve_ls" + argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], " + } + member_method { + name: "matrix_transpose" + argspec: "args=[\'a\', \'name\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\'], " + } + member_method { + name: "matrix_triangular_solve" + argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "meshgrid" + argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "min_max_variable_partitioner" + argspec: "args=[\'max_partitions\', \'axis\', \'min_slice_size\', \'bytes_per_string_element\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'262144\', \'16\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "mod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "model_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "moving_average_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "multinomial" + argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "multiply" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "name_scope" + argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "negative" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "no_op" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "no_regularizer" + argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "norm" + argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "one_hot" + argspec: "args=[\'indices\', \'depth\', \'on_value\', \'off_value\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "ones" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "ones_like" + argspec: "args=[\'tensor\', \'dtype\', \'name\', \'optimize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } + member_method { + name: "op_scope" + argspec: "args=[\'values\', \'name\', \'default_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pad" + argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], " + } + member_method { + name: "parallel_stack" + argspec: "args=[\'values\', \'name\'], varargs=None, keywords=None, defaults=[\'parallel_stack\'], " + } + member_method { + name: "parse_example" + argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "parse_single_example" + argspec: "args=[\'serialized\', 
\'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "parse_single_sequence_example" + argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "placeholder" + argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "placeholder_with_default" + argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pow" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "py_func" + argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "qr" + argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "quantize_v2" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_crop" + argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "random_gamma" + argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "random_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\', \'None\'], " + } + member_method { + name: "random_poisson" + argspec: "args=[\'lam\', \'shape\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\'], " + } + member_method { + name: "random_shuffle" + argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "random_uniform" + argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "range" + argspec: "args=[\'start\', \'limit\', \'delta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'range\'], " + } + member_method { + name: "rank" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "real" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + 
member_method { + name: "realdiv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reduce_all" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_any" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_join" + argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\'], " + } + member_method { + name: "reduce_logsumexp" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_max" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_mean" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_min" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_prod" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "reduce_sum" + argspec: "args=[\'input_tensor\', \'axis\', \'keep_dims\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "register_tensor_conversion_function" + argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " + } + member_method { + name: "report_uninitialized_variables" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], " + } + member_method { + name: "required_space_to_batch_paddings" + argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "reset_default_graph" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse_sequence" + argspec: "args=[\'input\', \'seq_lengths\', \'seq_axis\', \'batch_axis\', \'name\', \'seq_dim\', \'batch_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: 
"reverse_v2" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "round" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "saturate_cast" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scalar_mul" + argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scan" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], " + } + member_method { + name: "scatter_add" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_div" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_mul" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scatter_nd_add" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_nd_sub" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_nd_update" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_sub" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "scatter_update" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "self_adjoint_eig" + argspec: 
"args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "self_adjoint_eigvals" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sequence_mask" + argspec: "args=[\'lengths\', \'maxlen\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " + } + member_method { + name: "serialize_many_sparse" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_sparse" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_random_seed" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setdiff1d" + argspec: "args=[\'x\', \'y\', \'index_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "shape" + argspec: "args=[\'input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "shape_n" + argspec: "args=[\'input\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sign" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "slice" + argspec: "args=[\'input_\', \'begin\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch" + argspec: "args=[\'input\', \'paddings\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_depth" + argspec: "args=[\'input\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_add" + argspec: "args=[\'a\', \'b\', \'thresh\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "sparse_concat" + argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_fill_empty_rows" + argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_mask" + argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_matmul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_maximum" 
+ argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_merge" + argspec: "args=[\'sp_ids\', \'sp_values\', \'vocab_size\', \'name\', \'already_sorted\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "sparse_minimum" + argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_placeholder" + argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sparse_reduce_sum" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_reduce_sum_sparse" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_reorder" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_reset_shape" + argspec: "args=[\'sp_input\', \'new_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_reshape" + argspec: "args=[\'sp_input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_retain" + argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_segment_mean" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_segment_sqrt_n" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_segment_sum" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_softmax" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_split" + argspec: "args=[\'keyword_required\', \'sp_input\', \'num_split\', \'axis\', \'name\', \'split_dim\'], varargs=None, keywords=None, defaults=[\'KeywordRequired()\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_tensor_dense_matmul" + argspec: "args=[\'sp_a\', \'b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + } + member_method { + name: "sparse_tensor_to_dense" + argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], " + } + member_method { + name: "sparse_to_dense" + argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], " + } + member_method { + name: "sparse_to_indicator" + argspec: "args=[\'sp_input\', \'vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_transpose" + argspec: "args=[\'sp_input\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: 
"split" + argspec: "args=[\'value\', \'num_or_size_splits\', \'axis\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'split\'], " + } + member_method { + name: "sqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "square" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squeeze" + argspec: "args=[\'input\', \'axis\', \'name\', \'squeeze_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "stack" + argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'stack\'], " + } + member_method { + name: "stop_gradient" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "strided_slice" + argspec: "args=[\'input_\', \'begin\', \'end\', \'strides\', \'begin_mask\', \'end_mask\', \'ellipsis_mask\', \'new_axis_mask\', \'shrink_axis_mask\', \'var\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'0\', \'0\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "string_join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "string_split" + argspec: "args=[\'source\', \'delimiter\'], varargs=None, keywords=None, defaults=[\' \'], " + } + member_method { + name: "string_to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "subtract" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "svd" + argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " + } + member_method { + name: "tables_initializer" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensordot" + argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_bfloat16" + argspec: 
"args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToBFloat16\'], " + } + member_method { + name: "to_double" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToDouble\'], " + } + member_method { + name: "to_float" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToFloat\'], " + } + member_method { + name: "to_int32" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToInt32\'], " + } + member_method { + name: "to_int64" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToInt64\'], " + } + member_method { + name: "trace" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "trainable_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "transpose" + argspec: "args=[\'a\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\'], " + } + member_method { + name: "truediv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "truncated_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\', \'None\'], " + } + member_method { + name: "truncatediv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "truncatemod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tuple" + argspec: "args=[\'tensors\', \'name\', \'control_inputs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "unique" + argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "unique_with_counts" + argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unstack" + argspec: "args=[\'value\', \'num\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'unstack\'], " + } + member_method { + name: "variable_axis_size_partitioner" + argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], " + } + member_method { + name: "variable_op_scope" + argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "variable_scope" + argspec: "args=[\'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', 
\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "variables_initializer" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], " + } + member_method { + name: "verify_tensor_all_finite" + argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "where" + argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "while_loop" + argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeros" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "zeros_like" + argspec: "args=[\'tensor\', \'dtype\', \'name\', \'optimize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-compression-type.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-compression-type.pbtxt new file mode 100644 index 00000000000..4941dda50e4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-compression-type.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.python_io.TFRecordCompressionType" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "GZIP" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "ZLIB" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-options.pbtxt new file mode 100644 index 00000000000..0853716023a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-options.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.python_io.TFRecordOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "compression_type_map" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_compression_type_string" + argspec: "args=[\'cls\', \'options\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt new file mode 100644 index 00000000000..af0c11ca14d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.python_io.TFRecordWriter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + 
argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.pbtxt new file mode 100644 index 00000000000..7c9953e5fe3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.python_io.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.python_io" +tf_module { + member { + name: "TFRecordCompressionType" + mtype: "" + } + member { + name: "TFRecordOptions" + mtype: "" + } + member { + name: "TFRecordWriter" + mtype: "" + } + member_method { + name: "tf_record_iterator" + argspec: "args=[\'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt new file mode 100644 index 00000000000..5993fdeb9c2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.random_normal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt new file mode 100644 index 00000000000..a434ed1599e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.random_uniform_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.resource_loader.pbtxt b/tensorflow/tools/api/golden/tensorflow.resource_loader.pbtxt new file mode 100644 index 00000000000..288b78b4cd0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.resource_loader.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.resource_loader" +tf_module { + member_method { + name: "get_data_files_path" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_path_to_datafile" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_root_dir_with_all_resources" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_resource" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readahead_file_path" + argspec: "args=[\'path\', \'readahead\'], varargs=None, keywords=None, defaults=[\'128M\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt new file mode 100644 index 00000000000..56d76902fd0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.saved_model.builder.SavedModelBuilder" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_meta_graph" + argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "add_meta_graph_and_variables" + argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "save" + argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.pbtxt new file mode 100644 index 00000000000..adc697ad1c0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.saved_model.builder" +tf_module { + member { + name: "SavedModelBuilder" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.constants.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.constants.pbtxt new file mode 100644 index 00000000000..20e10aa094f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.constants.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.saved_model.constants" +tf_module { + member { + name: "ASSETS_DIRECTORY" + mtype: "" + } + member { + name: "ASSETS_KEY" + mtype: "" + } + member { + name: "LEGACY_INIT_OP_KEY" + mtype: "" + } + member { + name: "MAIN_OP_KEY" + mtype: "" + } + member { + name: "SAVED_MODEL_FILENAME_PB" + mtype: "" + } + member { + name: "SAVED_MODEL_FILENAME_PBTXT" + mtype: "" + } + member { + name: "SAVED_MODEL_SCHEMA_VERSION" + mtype: "" + } + member { + name: "VARIABLES_DIRECTORY" + mtype: "" + } + member { + name: "VARIABLES_FILENAME" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt new file mode 100644 index 00000000000..896e2160c69 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.loader" +tf_module { + member_method { + name: "load" + argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None" + } + member_method { + name: "maybe_saved_model_directory" + argspec: "args=[\'export_dir\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.main_op.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.main_op.pbtxt new file mode 100644 index 00000000000..176cb788c24 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.main_op.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.main_op" +tf_module { + member_method { + name: "main_op" + argspec: "args=[], varargs=None, 
keywords=None, defaults=None" + } + member_method { + name: "main_op_with_restore" + argspec: "args=[\'restore_op_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt new file mode 100644 index 00000000000..5683766b289 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.saved_model" +tf_module { + member { + name: "builder" + mtype: "" + } + member { + name: "constants" + mtype: "" + } + member { + name: "loader" + mtype: "" + } + member { + name: "main_op" + mtype: "" + } + member { + name: "signature_constants" + mtype: "" + } + member { + name: "signature_def_utils" + mtype: "" + } + member { + name: "tag_constants" + mtype: "" + } + member { + name: "utils" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.signature_constants.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.signature_constants.pbtxt new file mode 100644 index 00000000000..478d410e066 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.signature_constants.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.saved_model.signature_constants" +tf_module { + member { + name: "CLASSIFY_INPUTS" + mtype: "" + } + member { + name: "CLASSIFY_METHOD_NAME" + mtype: "" + } + member { + name: "CLASSIFY_OUTPUT_CLASSES" + mtype: "" + } + member { + name: "CLASSIFY_OUTPUT_SCORES" + mtype: "" + } + member { + name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY" + mtype: "" + } + member { + name: "PREDICT_INPUTS" + mtype: "" + } + member { + name: "PREDICT_METHOD_NAME" + mtype: "" + } + member { + name: "PREDICT_OUTPUTS" + mtype: "" + } + member { + name: "REGRESS_INPUTS" + mtype: "" + } + member { + name: "REGRESS_METHOD_NAME" + mtype: "" + } + member { + name: "REGRESS_OUTPUTS" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.signature_def_utils.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.signature_def_utils.pbtxt new file mode 100644 index 00000000000..e9867d84c3e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.signature_def_utils.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.saved_model.signature_def_utils" +tf_module { + member_method { + name: "build_signature_def" + argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "classification_signature_def" + argspec: "args=[\'examples\', \'classes\', \'scores\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_signature_def" + argspec: "args=[\'inputs\', \'outputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "regression_signature_def" + argspec: "args=[\'examples\', \'predictions\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt new file mode 100644 index 00000000000..7c24b7ad3cf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.tag_constants" +tf_module { + member { + name: "SERVING" + mtype: "" + } + member { + name: "TRAINING" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt new file mode 100644 index 00000000000..bc150e56a36 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.saved_model.utils" +tf_module { + member_method { + name: "build_tensor_info" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.sets.pbtxt b/tensorflow/tools/api/golden/tensorflow.sets.pbtxt new file mode 100644 index 00000000000..8a196b1a556 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.sets.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.sets" +tf_module { + member_method { + name: "set_difference" + argspec: "args=[\'a\', \'b\', \'aminusb\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "set_intersection" + argspec: "args=[\'a\', \'b\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "set_size" + argspec: "args=[\'a\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "set_union" + argspec: "args=[\'a\', \'b\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt new file mode 100644 index 00000000000..84883c1a395 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.spectral" +tf_module { + member_method { + name: "fft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "irfft" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "irfft2d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "irfft3d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft2d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft3d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt new file mode 100644 index 
00000000000..ab3449d80f6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.summary.Event" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FILE_VERSION_FIELD_NUMBER" + mtype: "" + } + member { + name: "GRAPH_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "LOG_MESSAGE_FIELD_NUMBER" + mtype: "" + } + member { + name: "META_GRAPH_DEF_FIELD_NUMBER" + mtype: "" + } + member { + name: "SESSION_LOG_FIELD_NUMBER" + mtype: "" + } + member { + name: "STEP_FIELD_NUMBER" + mtype: "" + } + member { + name: "SUMMARY_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAGGED_RUN_METADATA_FIELD_NUMBER" + mtype: "" + } + member { + name: "WALL_TIME_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer-cache.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer-cache.pbtxt new file mode 100644 index 00000000000..2a5b63dceae --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer-cache.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.summary.FileWriterCache" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "clear" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get" + argspec: "args=[\'logdir\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt new file mode 100644 index 00000000000..dcf747971b7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt @@ -0,0 +1,50 @@ +path: "tensorflow.summary.FileWriter" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logdir\', \'graph\', \'max_queue\', \'flush_secs\', \'graph_def\', \'filename_suffix\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'120\', \'None\', \'None\'], " + } + member_method { + name: "add_event" + argspec: "args=[\'self\', \'event\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_graph" + argspec: "args=[\'self\', \'graph\', \'global_step\', \'graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_meta_graph" + argspec: "args=[\'self\', \'meta_graph_def\', 
\'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_run_metadata" + argspec: "args=[\'self\', \'run_metadata\', \'tag\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_session_log" + argspec: "args=[\'self\', \'session_log\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_summary" + argspec: "args=[\'self\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_logdir" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reopen" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt new file mode 100644 index 00000000000..92ca4872caf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt @@ -0,0 +1,108 @@ +path: "tensorflow.summary.SessionLog" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CHECKPOINT" + mtype: "" + } + member { + name: "CHECKPOINT_PATH_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "MSG_FIELD_NUMBER" + mtype: "" + } + member { + name: "START" + mtype: "" + } + member { + name: "STATUS_FIELD_NUMBER" + mtype: "" + } + member { + name: "STATUS_UNSPECIFIED" + mtype: "" + } + member { + name: "STOP" + mtype: "" + } + member { + name: "SessionStatus" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt new file mode 100644 index 00000000000..f93da2196ad --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.summary.SummaryDescription" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "TYPE_HINT_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + 
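
The `tensorflow.summary.FileWriter` golden above pins the TF 1.x constructor surface (`logdir`, `graph`, `max_queue=10`, `flush_secs=120`, `graph_def`, `filename_suffix`) and the writer methods. A minimal usage sketch against that surface — the log directory and the scalar tag are hypothetical:

```python
import tensorflow as tf  # TF 1.x assumed, per the golden files in this change

x = tf.constant(3.0)
tf.summary.scalar("x", x)          # argspec: name, tensor, collections, family
merged = tf.summary.merge_all()    # default collection key: 'summaries'

with tf.Session() as sess:
    # Defaults per the golden file: max_queue=10, flush_secs=120.
    writer = tf.summary.FileWriter("/tmp/logdir", graph=sess.graph)
    writer.add_summary(sess.run(merged), global_step=0)
    writer.flush()
    writer.close()
```
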
member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt new file mode 100644 index 00000000000..605e305e82c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt @@ -0,0 +1,96 @@ +path: "tensorflow.summary.Summary.Audio" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CONTENT_TYPE_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ENCODED_AUDIO_STRING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "LENGTH_FRAMES_FIELD_NUMBER" + mtype: "" + } + member { + name: "NUM_CHANNELS_FIELD_NUMBER" + mtype: "" + } + member { + name: "SAMPLE_RATE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt new file mode 100644 index 00000000000..0646972196d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.summary.Summary.Image" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "COLORSPACE_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "ENCODED_IMAGE_STRING_FIELD_NUMBER" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "HEIGHT_FIELD_NUMBER" + mtype: "" + } + member { + name: "WIDTH_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + 
name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt new file mode 100644 index 00000000000..b319cd03d9e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.summary.Summary.Value" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AUDIO_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "HISTO_FIELD_NUMBER" + mtype: "" + } + member { + name: "IMAGE_FIELD_NUMBER" + mtype: "" + } + member { + name: "METADATA_FIELD_NUMBER" + mtype: "" + } + member { + name: "NODE_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER" + mtype: "" + } + member { + name: "SIMPLE_VALUE_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAG_FIELD_NUMBER" + mtype: "" + } + member { + name: "TENSOR_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt new file mode 100644 index 00000000000..132ef1b7d2e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.summary.Summary" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "Audio" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "Image" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member { + name: "Value" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } 
+ member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt new file mode 100644 index 00000000000..4dce20819de --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.summary.TaggedRunMetadata" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "RUN_METADATA_FIELD_NUMBER" + mtype: "" + } + member { + name: "TAG_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt new file mode 100644 index 00000000000..19d822e61bf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt @@ -0,0 +1,67 @@ +path: "tensorflow.summary" +tf_module { + member { + name: "Event" + mtype: "" + } + member { + name: "FileWriter" + mtype: "" + } + member { + name: "FileWriterCache" + mtype: "" + } + member { + name: "SessionLog" + mtype: "" + } + member { + name: "Summary" + mtype: "" + } + member { + name: "SummaryDescription" + mtype: "" + } + member { + name: "TaggedRunMetadata" + mtype: "" + } + member_method { + name: "audio" + argspec: "args=[\'name\', \'tensor\', \'sample_rate\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " + } + member_method { + name: "get_summary_description" + argspec: "args=[\'node_def\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "histogram" + argspec: "args=[\'name\', 
\'values\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "image" + argspec: "args=[\'name\', \'tensor\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " + } + member_method { + name: "merge" + argspec: "args=[\'inputs\', \'collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "merge_all" + argspec: "args=[\'key\'], varargs=None, keywords=None, defaults=[\'summaries\'], " + } + member_method { + name: "scalar" + argspec: "args=[\'name\', \'tensor\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "tensor_summary" + argspec: "args=[\'name\', \'tensor\', \'summary_description\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "text" + argspec: "args=[\'name\', \'tensor\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.sysconfig.pbtxt b/tensorflow/tools/api/golden/tensorflow.sysconfig.pbtxt new file mode 100644 index 00000000000..02dec04b9cc --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.sysconfig.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.sysconfig" +tf_module { + member_method { + name: "get_include" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_lib" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.-benchmark.pbtxt new file mode 100644 index 00000000000..df528e26b60 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.test.-benchmark.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.test.Benchmark" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "is_abstract" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "report_benchmark" + argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "run_op_benchmark" + argspec: "args=[\'self\', \'sess\', \'op_or_tensor\', \'feed_dict\', \'burn_iters\', \'min_iters\', \'store_trace\', \'store_memory_usage\', \'name\', \'extras\', \'mbs\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'10\', \'False\', \'True\', \'None\', \'None\', \'0\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt new file mode 100644 index 00000000000..e02a0c6097c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.test.StubOutForTesting" +tf_class { + is_instance: "" + member_method { + name: "CleanUp" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Set" + argspec: "args=[\'self\', \'parent\', \'child_name\', \'new_child\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SmartSet" + argspec: "args=[\'self\', \'obj\', \'attr_name\', \'new_attr\'], varargs=None, 
keywords=None, defaults=None" + } + member_method { + name: "SmartUnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "UnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.pbtxt new file mode 100644 index 00000000000..2a88f26ed02 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.test.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.test" +tf_module { + member { + name: "Benchmark" + mtype: "" + } + member { + name: "StubOutForTesting" + mtype: "" + } + member { + name: "TestCase" + mtype: "" + } + member { + name: "mock" + mtype: "" + } + member_method { + name: "assert_equal_graph_def" + argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "compute_gradient" + argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], " + } + member_method { + name: "compute_gradient_error" + argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], " + } + member_method { + name: "create_local_cluster" + argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\'], varargs=None, keywords=None, defaults=[\'grpc\'], " + } + member_method { + name: "get_temp_dir" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gpu_device_name" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_built_with_cuda" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_gpu_available" + argspec: "args=[\'cuda_only\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "main" + argspec: "args=[\'argv\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "test_src_dir_path" + argspec: "args=[\'relative_path\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt new file mode 100644 index 00000000000..8c91c5b4d9e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.AdadeltaOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.95\', \'1e-08\', \'False\', \'Adadelta\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', 
\'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt new file mode 100644 index 00000000000..05d38d62ccd --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.AdagradDAOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'global_step\', \'initial_gradient_squared_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'AdagradDA\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt new file mode 100644 index 00000000000..19ca9f57637 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.AdagradOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: 
"__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\', \'Adagrad\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt new file mode 100644 index 00000000000..c8144e2db78 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.AdamOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'False\', \'Adam\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt new file mode 100644 index 00000000000..8cf52b817f3 --- /dev/null +++ 
b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.BytesList" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt new file mode 100644 index 00000000000..c3037baa8c9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.CheckpointSaverHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-listener.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-listener.pbtxt new file mode 100644 index 00000000000..9d3688e5657 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-listener.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.train.CheckpointSaverListener" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "after_save" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_save" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-chief-session-creator.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-chief-session-creator.pbtxt new file mode 100644 index 00000000000..abbe273be32 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-chief-session-creator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.ChiefSessionCreator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt new file mode 100644 index 00000000000..93ff856b09d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.ClusterDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "JOB_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-spec.pbtxt new file mode 100644 index 00000000000..1658b15a5f8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-spec.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.train.ClusterSpec" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "jobs" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cluster\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_cluster_def" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_dict" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "job_tasks" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_tasks" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: 
"task_address" + argspec: "args=[\'self\', \'job_name\', \'task_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "task_indices" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-coordinator.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-coordinator.pbtxt new file mode 100644 index 00000000000..11277f077ee --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-coordinator.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.train.Coordinator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "joined" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'clean_stop_exception_types\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clear_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "join" + argspec: "args=[\'self\', \'threads\', \'stop_grace_period_secs\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'120\', \'False\'], " + } + member_method { + name: "raise_requested_exception" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "register_thread" + argspec: "args=[\'self\', \'thread\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_on_exception" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } + member_method { + name: "wait_for_stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt new file mode 100644 index 00000000000..f7215a20372 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.Example" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FEATURES_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt new file mode 100644 index 00000000000..737acbe07c9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.train.ExponentialMovingAverage" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "average" + argspec: "args=[\'self\', \'var\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "average_name" + argspec: "args=[\'self\', \'var\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables_to_restore" + argspec: "args=[\'self\', \'moving_avg_variables\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt new file mode 100644 index 00000000000..3ad98354d69 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.FeatureList" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FEATURE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt new file mode 100644 index 00000000000..cd171f4ca3e --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.FeatureLists.FeatureListEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + 
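
The `tensorflow.train.Example` golden above, together with the `BytesList`/`FloatList`/`Int64List` wrappers and the `Feature`/`Features` messages pinned just below, covers the standard `tf.train.Example` assembly path. A minimal sketch, assuming the TF 1.x surface in this change (the feature keys are hypothetical):

```python
import tensorflow as tf  # TF 1.x assumed

# Keys ("label", "weight", "name") are made up for illustration.
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "weight": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"sample"])),
}))
serialized = example.SerializeToString()
roundtrip = tf.train.Example.FromString(serialized)
```
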
member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt new file mode 100644 index 00000000000..3d95017d584 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.FeatureLists" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FEATURE_LIST_FIELD_NUMBER" + mtype: "" + } + member { + name: "FeatureListEntry" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt new file mode 100644 index 00000000000..9cca132bba9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.train.Feature" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "BYTES_LIST_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FLOAT_LIST_FIELD_NUMBER" + mtype: "" + } + member { + name: "INT64_LIST_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method 
{ + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt new file mode 100644 index 00000000000..858aee03415 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.Features.FeatureEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt new file mode 100644 index 00000000000..49cd12153bf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.Features" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FEATURE_FIELD_NUMBER" + mtype: "" + } + member { + name: "FeatureEntry" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feed-fn-hook.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.train.-feed-fn-hook.pbtxt new file mode 100644 index 00000000000..7bec4d032ce --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-feed-fn-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.FeedFnHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feed_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-final-ops-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-final-ops-hook.pbtxt new file mode 100644 index 00000000000..31cf9aaeb2c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-final-ops-hook.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.train.FinalOpsHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "final_ops_values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'final_ops\', \'final_ops_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt new file mode 100644 index 00000000000..e3f01334b54 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.FloatList" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: 
"MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt new file mode 100644 index 00000000000..2dc11df57b6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.FtrlOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-global-step-waiter-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-global-step-waiter-hook.pbtxt new file mode 100644 index 00000000000..147448618e2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-global-step-waiter-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.GlobalStepWaiterHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'wait_until_step\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, 
keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt new file mode 100644 index 00000000000..bdd4c525685 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.GradientDescentOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'GradientDescent\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt new file mode 100644 index 00000000000..8917dc122cf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt @@ -0,0 +1,80 @@ +path: "tensorflow.train.Int64List" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: 
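All the optimizer goldens in this batch (`FtrlOptimizer`, `GradientDescentOptimizer`, `MomentumOptimizer`, ...) share the base `Optimizer` surface: `compute_gradients`/`apply_gradients`, the `minimize` shorthand, and the slot accessors. A toy sketch against the `GradientDescentOptimizer` argspec above; the quadratic loss is invented for illustration:

```python
import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x - 1.0)  # toy loss, minimum at x == 1

# __init__(learning_rate, use_locking=False, name='GradientDescent'),
# as the golden records.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)
    print(sess.run(x))  # converges toward 1.0
```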
"WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt new file mode 100644 index 00000000000..ac6d81541a4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.JobDef.TasksEntry" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "KEY_FIELD_NUMBER" + mtype: "" + } + member { + name: "VALUE_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt new file mode 100644 index 00000000000..ce34537fa13 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt @@ -0,0 +1,88 @@ +path: "tensorflow.train.JobDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "TASKS_FIELD_NUMBER" + mtype: "" + } + member { + name: "TasksEntry" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-logging-tensor-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-logging-tensor-hook.pbtxt new file mode 100644 index 00000000000..9801c05df18 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-logging-tensor-hook.pbtxt @@ -0,0 +1,30 @@ 
+path: "tensorflow.train.LoggingTensorHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'tensors\', \'every_n_iter\', \'every_n_secs\', \'at_end\', \'formatter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-looper-thread.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-looper-thread.pbtxt new file mode 100644 index 00000000000..c61859004e8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-looper-thread.pbtxt @@ -0,0 +1,73 @@ +path: "tensorflow.train.LooperThread" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "daemon" + mtype: "" + } + member { + name: "ident" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "getName" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "isAlive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "isDaemon" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_alive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "join" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "loop" + argspec: "args=[\'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setDaemon" + argspec: "args=[\'self\', \'daemonic\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setName" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt new file mode 100644 index 00000000000..7cf5488a15e --- /dev/null 
+++ b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.MomentumOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'use_locking\', \'name\', \'use_nesterov\'], varargs=None, keywords=None, defaults=[\'False\', \'Momentum\', \'False\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt new file mode 100644 index 00000000000..3a5cc015b4d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.train.MonitoredSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'session_creator\', \'hooks\', \'stop_grace_period_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'120\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-nan-loss-during-training-error.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-nan-loss-during-training-error.pbtxt new file mode 100644 index 00000000000..25fd5e75a79 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-nan-loss-during-training-error.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.train.NanLossDuringTrainingError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-nan-tensor-hook.pbtxt 
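`MonitoredSession`'s golden pins a deliberately small run-loop surface: `run`, `should_stop`, `close`, plus the `graph` property. The idiomatic loop it implies, sketched with a hypothetical `train_op` and a `StopAtStepHook` (whose golden appears later in this diff):

```python
import tensorflow as tf

step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(step, 1)  # hypothetical training step

# StopAtStepHook flips should_stop() once the global step reaches 100.
hooks = [tf.train.StopAtStepHook(last_step=100)]

with tf.train.MonitoredSession(hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```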
b/tensorflow/tools/api/golden/tensorflow.train.-nan-tensor-hook.pbtxt new file mode 100644 index 00000000000..7d1c89f9b37 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-nan-tensor-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.NanTensorHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loss_tensor\', \'fail_on_nan_loss\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt new file mode 100644 index 00000000000..20b0c4d1b56 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.train.Optimizer" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt new file mode 100644 index 00000000000..571d846b6c5 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.ProximalAdagradOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: 
"GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'ProximalAdagrad\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt new file mode 100644 index 00000000000..1feb136e7f7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.ProximalGradientDescentOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.0\', \'False\', \'ProximalGradientDescent\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', 
\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-queue-runner.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-queue-runner.pbtxt new file mode 100644 index 00000000000..d84d0058eea --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-queue-runner.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.train.QueueRunner" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cancel_op" + mtype: "" + } + member { + name: "close_op" + mtype: "" + } + member { + name: "enqueue_ops" + mtype: "" + } + member { + name: "exceptions_raised" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "queue" + mtype: "" + } + member { + name: "queue_closed_exception_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_threads" + argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt new file mode 100644 index 00000000000..2aa4ae6d2d2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.train.RMSPropOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'decay\', \'momentum\', \'epsilon\', \'use_locking\', \'centered\', \'name\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.0\', \'1e-10\', \'False\', \'False\', \'RMSProp\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', 
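`QueueRunner`'s golden records the `queue`/`enqueue_ops` constructor, `create_threads`, and the proto round-trip (`to_proto`/`from_proto`). A sketch of the classic TF-1.x wiring, with a toy FIFO queue and a `Coordinator` (both invented for illustration):

```python
import tensorflow as tf

queue = tf.FIFOQueue(capacity=32, dtypes=[tf.float32])
enqueue_op = queue.enqueue(tf.random_normal([]))

# Four threads all running the same enqueue op.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
coord = tf.train.Coordinator()

with tf.Session() as sess:
    # create_threads(sess, coord, daemon, start), per the recorded argspec.
    threads = qr.create_threads(sess, coord=coord, start=True)
    print(sess.run(queue.dequeue_many(8)))
    coord.request_stop()
    coord.join(threads)
```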
\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt new file mode 100644 index 00000000000..84498a64f5b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt @@ -0,0 +1,120 @@ +path: "tensorflow.train.SaverDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CheckpointFormatVersion" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FILENAME_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "KEEP_CHECKPOINT_EVERY_N_HOURS_FIELD_NUMBER" + mtype: "" + } + member { + name: "LEGACY" + mtype: "" + } + member { + name: "MAX_TO_KEEP_FIELD_NUMBER" + mtype: "" + } + member { + name: "RESTORE_OP_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "SAVE_TENSOR_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "SHARDED_FIELD_NUMBER" + mtype: "" + } + member { + name: "V1" + mtype: "" + } + member { + name: "V2" + mtype: "" + } + member { + name: "VERSION_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt new file mode 100644 index 00000000000..04c11712cd4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt @@ -0,0 +1,53 @@ +path: "tensorflow.train.Saver" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "last_checkpoints" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'var_list\', \'reshape\', \'sharded\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'name\', \'restore_sequentially\', \'saver_def\', \'builder\', \'defer_build\', \'allow_empty\', \'write_version\', \'pad_step_number\', \'save_relative_paths\', \'filename\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'5\', \'10000.0\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'2\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "as_saver_def" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "export_meta_graph" + argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', 
\'None\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'saver_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "recover_last_checkpoints" + argspec: "args=[\'self\', \'checkpoint_paths\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "restore" + argspec: "args=[\'self\', \'sess\', \'save_path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\'], " + } + member_method { + name: "set_last_checkpoints" + argspec: "args=[\'self\', \'last_checkpoints\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_last_checkpoints_with_time" + argspec: "args=[\'self\', \'last_checkpoints_with_time\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt new file mode 100644 index 00000000000..62b956c5ef7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.train.Scaffold" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "init_feed_dict" + mtype: "" + } + member { + name: "init_fn" + mtype: "" + } + member { + name: "init_op" + mtype: "" + } + member { + name: "local_init_op" + mtype: "" + } + member { + name: "ready_for_local_init_op" + mtype: "" + } + member { + name: "ready_op" + mtype: "" + } + member { + name: "saver" + mtype: "" + } + member { + name: "summary_op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_or_default" + argspec: "args=[\'arg_name\', \'collection_key\', \'default_constructor\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-second-or-step-timer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-second-or-step-timer.pbtxt new file mode 100644 index 00000000000..3c5a6ac13cc --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-second-or-step-timer.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.train.SecondOrStepTimer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'every_secs\', \'every_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "last_triggered_step" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "should_trigger_for_step" + argspec: "args=[\'self\', \'step\'], 
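`Saver`'s golden carries the largest constructor in the batch (fifteen arguments, with `max_to_keep=5` and `write_version=2` among the defaults), but typical use touches only `save` and `restore`. A sketch; the `/tmp/model.ckpt` prefix is illustrative:

```python
import tensorflow as tf

v = tf.Variable(42.0, name="v")
saver = tf.train.Saver()  # all defaults, per the recorded argspec

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # save() returns the exact path prefix it wrote (with the step suffix).
    path = saver.save(sess, "/tmp/model.ckpt", global_step=0)

with tf.Session() as sess:
    saver.restore(sess, path)  # restored variables need no initializer
    print(sess.run(v))
```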
varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_last_triggered_step" + argspec: "args=[\'self\', \'step\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt new file mode 100644 index 00000000000..9ab95537021 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt @@ -0,0 +1,84 @@ +path: "tensorflow.train.SequenceExample" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CONTEXT_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "FEATURE_LISTS_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt new file mode 100644 index 00000000000..af0a3b73cc2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt @@ -0,0 +1,96 @@ +path: "tensorflow.train.ServerDef" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "CLUSTER_FIELD_NUMBER" + mtype: "" + } + member { + name: "DEFAULT_SESSION_CONFIG_FIELD_NUMBER" + mtype: "" + } + member { + name: "DESCRIPTOR" + mtype: "" + } + member { + name: "Extensions" + mtype: "" + } + member { + name: "JOB_NAME_FIELD_NUMBER" + mtype: "" + } + member { + name: "PROTOCOL_FIELD_NUMBER" + mtype: "" + } + member { + name: "TASK_INDEX_FIELD_NUMBER" + mtype: "" + } + member_method { + name: "ByteSize" + } + member_method { + name: "Clear" + } + member_method { + name: "ClearExtension" + } + member_method { + name: "ClearField" + } + member_method { + name: "CopyFrom" + } + member_method { + name: "DiscardUnknownFields" + } + member_method { + name: "FindInitializationErrors" + } + member_method { + name: "FromString" + } + member_method { + name: "HasExtension" + } + member_method { + name: "HasField" + } + member_method { + name: "IsInitialized" + } + member_method { + name: "ListFields" + } + member_method { + name: "MergeFrom" + } + member_method { + name: "MergeFromString" + } + member_method { + name: "ParseFromString" + } + member_method { + name: "RegisterExtension" + } + member_method { + name: "SerializePartialToString" + } + member_method { + name: "SerializeToString" + } + member_method { + name: "SetInParent" + } + member_method { + name: "WhichOneof" + } + member_method { + name: "__init__" + } +} diff --git 
a/tensorflow/tools/api/golden/tensorflow.train.-server.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-server.pbtxt new file mode 100644 index 00000000000..9b8f185f5b6 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-server.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.train.Server" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "server_def" + mtype: "" + } + member { + name: "target" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'server_or_cluster_def\', \'job_name\', \'task_index\', \'protocol\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "create_local_server" + argspec: "args=[\'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "join" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-creator.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-creator.pbtxt new file mode 100644 index 00000000000..beb232715f7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-creator.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.SessionCreator" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt new file mode 100644 index 00000000000..cc31bb4e4b3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.train.SessionManager" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\'], " + } + member_method { + name: "prepare_session" + argspec: "args=[\'self\', \'master\', \'init_op\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\', \'init_feed_dict\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'7200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recover_session" + argspec: "args=[\'self\', \'master\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'7200\', \'None\'], " + } + member_method { + name: "wait_for_session" + argspec: "args=[\'self\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'inf\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-args.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-run-args.pbtxt new file mode 100644 index 00000000000..442990893e3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-run-args.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.train.SessionRunArgs" +tf_class { + 
is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "feed_dict" + mtype: "" + } + member { + name: "fetches" + mtype: "" + } + member { + name: "options" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-context.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-run-context.pbtxt new file mode 100644 index 00000000000..d5adb15c95f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-run-context.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.train.SessionRunContext" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "original_args" + mtype: "" + } + member { + name: "session" + mtype: "" + } + member { + name: "stop_requested" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'original_args\', \'session\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-run-hook.pbtxt new file mode 100644 index 00000000000..db1aa24acf0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-run-hook.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.train.SessionRunHook" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-values.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-run-values.pbtxt new file mode 100644 index 00000000000..0b401d59c40 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-run-values.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.train.SessionRunValues" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "options" + mtype: "" + } + member { + name: "results" + mtype: "" + } + member { + name: "run_metadata" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt new file mode 100644 index 00000000000..62bfdab40bb --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.SingularMonitoredSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hooks\', \'scaffold\', \'master\', \'config\', 
\'checkpoint_dir\', \'stop_grace_period_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'120\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "raw_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-step-counter-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-step-counter-hook.pbtxt new file mode 100644 index 00000000000..13261f6dde1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-step-counter-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.StepCounterHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'every_n_steps\', \'every_n_secs\', \'output_dir\', \'summary_writer\'], varargs=None, keywords=None, defaults=[\'100\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-stop-at-step-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-stop-at-step-hook.pbtxt new file mode 100644 index 00000000000..e388599b0bf --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-stop-at-step-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.StopAtStepHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_steps\', \'last_step\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-summary-saver-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-summary-saver-hook.pbtxt new file mode 100644 index 00000000000..697c3667b09 --- /dev/null +++ 
b/tensorflow/tools/api/golden/tensorflow.train.-summary-saver-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.SummarySaverHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'save_steps\', \'save_secs\', \'output_dir\', \'summary_writer\', \'scaffold\', \'summary_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt new file mode 100644 index 00000000000..cc9bd5c136b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt @@ -0,0 +1,153 @@ +path: "tensorflow.train.Supervisor" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "USE_DEFAULT" + mtype: "" + } + member { + name: "coord" + mtype: "" + } + member { + name: "global_step" + mtype: "" + } + member { + name: "init_feed_dict" + mtype: "" + } + member { + name: "init_op" + mtype: "" + } + member { + name: "is_chief" + mtype: "" + } + member { + name: "ready_for_local_init_op" + mtype: "" + } + member { + name: "ready_op" + mtype: "" + } + member { + name: "save_model_secs" + mtype: "" + } + member { + name: "save_path" + mtype: "" + } + member { + name: "save_summaries_secs" + mtype: "" + } + member { + name: "saver" + mtype: "" + } + member { + name: "session_manager" + mtype: "" + } + member { + name: "summary_op" + mtype: "" + } + member { + name: "summary_writer" + mtype: "" + } + member_method { + name: "Loop" + argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "PrepareSession" + argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], " + } + member_method { + name: "RequestStop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ShouldStop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "StartQueueRunners" + argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "StartStandardServices" + argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Stop" + argspec: "args=[\'self\', \'threads\', \'close_summary_writer\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "StopOnException" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + 
member_method { + name: "SummaryComputed" + argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "WaitForStop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\'], " + } + member_method { + name: "loop" + argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "managed_session" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } + member_method { + name: "prepare_or_wait_for_session" + argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], " + } + member_method { + name: "request_stop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "start_standard_services" + argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'threads\', \'close_summary_writer\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "stop_on_exception" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary_computed" + argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "wait_for_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt new file mode 100644 index 00000000000..915d8501af0 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.train.SyncReplicasOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'opt\', \'replicas_to_aggregate\', \'total_num_replicas\', \'variable_averages\', \'variables_to_average\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'sync_replicas\'], " + } + member_method { + name: 
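`Supervisor` (largely superseded by `MonitoredTrainingSession`, but still part of the recorded surface) centers on `managed_session`, whose argspec above shows only `*args`/`**kwds` because it is a context manager. A sketch; the logdir and step limit are illustrative:

```python
import tensorflow as tf

step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(step, 1)  # stand-in training step

# The Supervisor checkpoints and writes summaries under logdir.
sv = tf.train.Supervisor(logdir="/tmp/train_logs")
with sv.managed_session() as sess:
    while not sv.should_stop():
        if sess.run(step) >= 100:
            sv.request_stop()
        else:
            sess.run(train_op)
```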
"apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_chief_queue_runner" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_init_tokens_op" + argspec: "args=[\'self\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "make_session_run_hook" + argspec: "args=[\'self\', \'is_chief\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt new file mode 100644 index 00000000000..140407651a9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.WorkerSessionCreator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\'], " + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt new file mode 100644 index 00000000000..58fd5760c11 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -0,0 +1,407 @@ +path: "tensorflow.train" +tf_module { + member { + name: "AdadeltaOptimizer" + mtype: "" + } + member { + name: "AdagradDAOptimizer" + mtype: "" + } + member { + name: "AdagradOptimizer" + mtype: "" + } + member { + name: "AdamOptimizer" + mtype: "" + } + member { + name: "BytesList" + mtype: "" + } + member { + name: "CheckpointSaverHook" + mtype: "" + } + member { + name: "CheckpointSaverListener" + mtype: "" + } + member { + name: "ChiefSessionCreator" + mtype: "" + } + member { + name: "ClusterDef" + mtype: "" + } + member { + name: "ClusterSpec" + mtype: "" + } + member { + name: "Coordinator" + mtype: "" + } + member { + name: "Example" + mtype: "" + } + member { + name: "ExponentialMovingAverage" + mtype: "" + } + member { + name: "Feature" + mtype: "" + } + member { + name: "FeatureList" + mtype: "" + } + member { + name: "FeatureLists" + mtype: "" + } + member { + name: "Features" + mtype: "" + } + member { + name: "FeedFnHook" + mtype: "" + } + member { + name: "FinalOpsHook" + mtype: "" + } + member { + name: "FloatList" + mtype: "" + } + member { + name: "FtrlOptimizer" + mtype: "" + } + member { + name: 
"GlobalStepWaiterHook" + mtype: "" + } + member { + name: "GradientDescentOptimizer" + mtype: "" + } + member { + name: "Int64List" + mtype: "" + } + member { + name: "JobDef" + mtype: "" + } + member { + name: "LoggingTensorHook" + mtype: "" + } + member { + name: "LooperThread" + mtype: "" + } + member { + name: "MomentumOptimizer" + mtype: "" + } + member { + name: "MonitoredSession" + mtype: "" + } + member { + name: "NanLossDuringTrainingError" + mtype: "" + } + member { + name: "NanTensorHook" + mtype: "" + } + member { + name: "Optimizer" + mtype: "" + } + member { + name: "ProximalAdagradOptimizer" + mtype: "" + } + member { + name: "ProximalGradientDescentOptimizer" + mtype: "" + } + member { + name: "QueueRunner" + mtype: "" + } + member { + name: "RMSPropOptimizer" + mtype: "" + } + member { + name: "Saver" + mtype: "" + } + member { + name: "SaverDef" + mtype: "" + } + member { + name: "Scaffold" + mtype: "" + } + member { + name: "SecondOrStepTimer" + mtype: "" + } + member { + name: "SequenceExample" + mtype: "" + } + member { + name: "Server" + mtype: "" + } + member { + name: "ServerDef" + mtype: "" + } + member { + name: "SessionCreator" + mtype: "" + } + member { + name: "SessionManager" + mtype: "" + } + member { + name: "SessionRunArgs" + mtype: "" + } + member { + name: "SessionRunContext" + mtype: "" + } + member { + name: "SessionRunHook" + mtype: "" + } + member { + name: "SessionRunValues" + mtype: "" + } + member { + name: "SingularMonitoredSession" + mtype: "" + } + member { + name: "StepCounterHook" + mtype: "" + } + member { + name: "StopAtStepHook" + mtype: "" + } + member { + name: "SummarySaverHook" + mtype: "" + } + member { + name: "Supervisor" + mtype: "" + } + member { + name: "SyncReplicasOptimizer" + mtype: "" + } + member { + name: "WorkerSessionCreator" + mtype: "" + } + member { + name: "queue_runner" + mtype: "" + } + member_method { + name: "MonitoredTrainingSession" + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'100\', \'None\', \'None\', \'120\', \'100\'], " + } + member_method { + name: "NewCheckpointReader" + argspec: "args=[\'filepattern\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_queue_runner" + argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], " + } + member_method { + name: "assert_global_step" + argspec: "args=[\'global_step_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "basic_train_loop" + argspec: "args=[\'supervisor\', \'train_step_fn\', \'args\', \'kwargs\', \'master\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\'], " + } + member_method { + name: "batch" + argspec: "args=[\'tensors\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, 
keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "checkpoint_exists" + argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "create_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "do_quantize_training_on_graphdef" + argspec: "args=[\'input_graph\', \'num_bits\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exponential_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "export_meta_graph" + argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\'], " + } + member_method { + name: "generate_checkpoint_state_proto" + argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_checkpoint_mtimes" + argspec: "args=[\'checkpoint_prefixes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_checkpoint_state" + argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_or_create_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "global_step" + argspec: "args=[\'sess\', \'global_step_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "import_meta_graph" + argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " + } + member_method { + name: "input_producer" + argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "inverse_time_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "limit_epochs" + argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "match_filenames_once" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maybe_batch" + argspec: "args=[\'tensors\', \'keep_input\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], 
varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_batch_join" + argspec: "args=[\'tensors_list\', \'keep_input\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_shuffle_batch" + argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_shuffle_batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "natural_exp_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "piecewise_constant" + argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polynomial_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " + } + member_method { + name: "range_input_producer" + argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " + } + member_method { + name: "replica_device_setter" + argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sdca_fprint" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sdca_optimizer" + argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sdca_shrink_l1" + argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shuffle_batch" + argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + 
member_method { + name: "shuffle_batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "slice_input_producer" + argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], " + } + member_method { + name: "string_input_producer" + argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "summary_iterator" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_checkpoint_state" + argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "write_graph" + argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.queue_runner.-queue-runner.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.queue_runner.-queue-runner.pbtxt new file mode 100644 index 00000000000..23d402de308 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.queue_runner.-queue-runner.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.train.queue_runner.QueueRunner" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cancel_op" + mtype: "" + } + member { + name: "close_op" + mtype: "" + } + member { + name: "enqueue_ops" + mtype: "" + } + member { + name: "exceptions_raised" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "queue" + mtype: "" + } + member { + name: "queue_closed_exception_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_threads" + argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.queue_runner.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.queue_runner.pbtxt new file mode 100644 index 00000000000..6e2d0430496 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.train.queue_runner.pbtxt @@ -0,0 +1,15 @@ 
+path: "tensorflow.train.queue_runner" +tf_module { + member { + name: "QueueRunner" + mtype: "" + } + member_method { + name: "add_queue_runner" + argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], " + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt new file mode 100644 index 00000000000..c1e1c230a9f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.truncated_normal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt new file mode 100644 index 00000000000..e1b18dc92fb --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.uniform_unit_scaling_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt new file mode 100644 index 00000000000..e229b02ceec --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.zeros_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD new file mode 100644 index 00000000000..cdfa0e7be52 --- /dev/null +++ b/tensorflow/tools/api/lib/BUILD @@ -0,0 +1,39 @@ +# Helper libraries for TensorFlow API compatibility test. 
+ +package( + default_visibility = ["//tensorflow/tools/api:__subpackages__"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) + +tf_proto_library( + name = "api_objects_proto", + srcs = ["api_objects.proto"], +) + +py_library( + name = "python_object_to_proto_visitor", + srcs = ["python_object_to_proto_visitor.py"], + srcs_version = "PY2AND3", + deps = [ + ":api_objects_proto_py", + "//tensorflow/tools/common:traverse", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/api/lib/api_objects.proto b/tensorflow/tools/api/lib/api_objects.proto new file mode 100644 index 00000000000..0966a5f1d53 --- /dev/null +++ b/tensorflow/tools/api/lib/api_objects.proto @@ -0,0 +1,31 @@ +syntax = "proto2"; + +package third_party.tensorflow.tools.api; + +message TFAPIMember { + optional string name = 1; + optional string mtype = 2; +}; + +message TFAPIMethod { + optional string name = 1; + optional string path = 2; + optional string argspec = 3; +}; + +message TFAPIModule { + repeated TFAPIMember member = 1; + repeated TFAPIMethod member_method = 2; +}; + +message TFAPIClass { + repeated string is_instance = 1; + repeated TFAPIMember member = 2; + repeated TFAPIMethod member_method = 3; +}; + +message TFAPIObject { + optional string path = 1; + optional TFAPIModule tf_module = 2; + optional TFAPIClass tf_class = 3; +}; diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py new file mode 100644 index 00000000000..43ba52f9834 --- /dev/null +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -0,0 +1,173 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +"""A visitor class that generates protobufs for each python object.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect +from tensorflow.tools.api.lib import api_objects_pb2 + +# The following objects need to be handled individually. +_CORNER_CASES = { + '': {'tools': {}}, + 'test.TestCase': {}, + 'test.TestCase.failureException': {}, +} + + +def _SanitizedArgSpec(obj): + """Get an ArgSpec string that is free of addresses. + + We have callables as function arg defaults. This results in addresses in + getargspec output. This function returns a sanitized string representation + of the argspec. + + Args: + obj: A python routine for us to create the sanitized argspec of. + + Returns: + string, a string representation of the argspec.
+ """ + output_string = '' + unsanitized_arg_spec = tf_inspect.getargspec(obj) + + for clean_attr in ('args', 'varargs', 'keywords'): + output_string += '%s=%s, ' % (clean_attr, + getattr(unsanitized_arg_spec, clean_attr)) + + if unsanitized_arg_spec.defaults: + sanitized_defaults = [] + for val in unsanitized_arg_spec.defaults: + str_val = str(val) + # Sanitize argspecs that have hex code in them. + if ' at 0x' in str_val: + sanitized_defaults.append('%s instance>' % str_val.split(' at ')[0]) + else: + sanitized_defaults.append(str_val) + + output_string += 'defaults=%s, ' % sanitized_defaults + + else: + output_string += 'defaults=None' + + return output_string + + +def _SanitizedMRO(obj): + """Get a list of superclasses with minimal amount of non-TF classes. + + Based on many parameters like python version, OS, protobuf implementation + or changes in google core libraries the list of superclasses of a class + can change. We only return the first non-TF class to be robust to non API + affecting changes. The Method Resolution Order returned by `tf_inspect.getmro` + is still maintained in the return value. + + Args: + obj: A python routine for us the create the sanitized arspec of. + + Returns: + list of strings, string representation of the class names. + """ + return_list = [] + for cls in tf_inspect.getmro(obj): + str_repr = str(cls) + return_list.append(str_repr) + if 'tensorflow' not in str_repr: + break + + # Hack - tensorflow.test.StubOutForTesting may or may not be type + # depending on the environment. To avoid inconsistency, break after we add + # StubOutForTesting to the return_list. + if 'StubOutForTesting' in str_repr: + break + + return return_list + + +class PythonObjectToProtoVisitor(object): + """A visitor that summarizes given python objects as protobufs.""" + + def __init__(self): + # A dict to store all protocol buffers. + # Keyed by "path" to the object. + self._protos = {} + + def GetProtos(self): + """Return the list of protos stored.""" + return self._protos + + def __call__(self, path, parent, children): + # The path to the object. + lib_path = 'tensorflow.%s' % path if path else 'tensorflow' + + # A small helper method to construct members(children) protos. + def _AddMember(member_name, member_obj, proto): + """Add the child object to the object being constructed.""" + _, member_obj = tf_decorator.unwrap(member_obj) + if member_name == '__init__' or not member_name.startswith('_'): + if tf_inspect.isroutine(member_obj): + new_method = proto.member_method.add() + new_method.name = member_name + # If member_obj is a python builtin, there is no way to get its + # argspec, because it is implemented on the C side. It also has no + # func_code. + if getattr(member_obj, 'func_code', None): + new_method.argspec = _SanitizedArgSpec(member_obj) + else: + new_member = proto.member.add() + new_member.name = member_name + new_member.mtype = str(type(member_obj)) + + parent_corner_cases = _CORNER_CASES.get(path, {}) + + if path not in _CORNER_CASES or parent_corner_cases: + # Decide if we have a module or a class. + if tf_inspect.ismodule(parent): + # Create a module object. + module_obj = api_objects_pb2.TFAPIModule() + for name, child in children: + if name in parent_corner_cases: + # If we have an empty entry, skip this object. + if parent_corner_cases[name]: + module_obj.member.add(**(parent_corner_cases[name])) + else: + _AddMember(name, child, module_obj) + + # Store the constructed module object. 
+ self._protos[lib_path] = api_objects_pb2.TFAPIObject( + path=lib_path, tf_module=module_obj) + elif tf_inspect.isclass(parent): + # Construct a class. + class_obj = api_objects_pb2.TFAPIClass() + class_obj.is_instance.extend(_SanitizedMRO(parent)) + for name, child in children: + if name in parent_corner_cases: + # If we have an empty entry, skip this object. + if parent_corner_cases[name]: + class_obj.member.add(**(parent_corner_cases[name])) + else: + _AddMember(name, child, class_obj) + + # Store the constructed class object. + self._protos[lib_path] = api_objects_pb2.TFAPIObject( + path=lib_path, tf_class=class_obj) + else: + logging.error('Illegal call to ApiProtoDump::_py_obj_to_proto. ' + 'Object is neither a module nor a class: %s', path) diff --git a/tensorflow/tools/api/tests/API_UPDATE_WARNING.txt b/tensorflow/tools/api/tests/API_UPDATE_WARNING.txt new file mode 100644 index 00000000000..54b0cfcb3c1 --- /dev/null +++ b/tensorflow/tools/api/tests/API_UPDATE_WARNING.txt @@ -0,0 +1,7 @@ +Golden file update requested! +All test failures have been skipped, see the logs for detected diffs. +This test is now going to write new golden files. +Make sure to package the updates together with your change. + +You will need an explicit API approval. This may take longer than a normal +review. diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD new file mode 100644 index 00000000000..8421d8fce28 --- /dev/null +++ b/tensorflow/tools/api/tests/BUILD @@ -0,0 +1,44 @@ +# TensorFlow API backwards compatibility tests. + +package( + default_visibility = ["//tensorflow/tools/api:__subpackages__"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "README.txt", + "API_UPDATE_WARNING.txt", +]) + +py_test( + name = "api_compatibility_test", + size = "small", + srcs = ["api_compatibility_test.py"], + data = [ + "//tensorflow/tools/api/golden:api_golden", + "//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt", + "//tensorflow/tools/api/tests:README.txt", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + "//tensorflow/tools/api/lib:python_object_to_proto_visitor", + "//tensorflow/tools/common:public_api", + "//tensorflow/tools/common:traverse", + "@protobuf//:protobuf_python", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/api/tests/README.txt b/tensorflow/tools/api/tests/README.txt new file mode 100644 index 00000000000..3463eeec6fe --- /dev/null +++ b/tensorflow/tools/api/tests/README.txt @@ -0,0 +1,13 @@ +TensorFlow API backwards compatibility test +This test ensures all changes to the public API of TensorFlow are intended. + +If this test fails, it means a change has been made to the public API. Backwards +incompatible changes are not allowed. You can run the test as follows to update +test goldens and package them with your change. + + $ bazel build tensorflow/tools/api/tests:api_compatibility_test + $ bazel-bin/tensorflow/tools/api/tests/api_compatibility_test \ + --update_goldens True + +You will need an API approval to make changes to the public TensorFlow API. This +includes additions to the API.
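To make the golden-file round trip described in this README concrete, here is a minimal sketch, assuming the `api_objects_pb2` module generated from `api_objects.proto` above; the helpers `key_to_file_path` and `read_golden` are hypothetical stand-ins that mirror the test's private logic, not part of the change itself:

```python
# Illustrative sketch only: map an API key to its golden file and parse it back.
import os
import re

from google.protobuf import text_format

# Assumption: this module was generated from api_objects.proto above.
from tensorflow.tools.api.lib import api_objects_pb2

_API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'


def key_to_file_path(key):
  # Each CamelCase capital becomes a dash plus its lowercase form, mirroring
  # the test: 'tensorflow.train.AdamOptimizer' ->
  # 'tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt'
  dashed = re.sub('([A-Z]{1})', lambda m: '-%s' % m.group(0).lower(), key)
  return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % dashed)


def read_golden(filename):
  # Goldens are text-format protos; Merge() fills a fresh TFAPIObject.
  proto = api_objects_pb2.TFAPIObject()
  with open(filename) as f:
    text_format.Merge(f.read(), proto)
  return proto


if __name__ == '__main__':
  path = key_to_file_path('tensorflow.train.AdamOptimizer')
  obj = read_golden(path)
  print(obj.path, len(obj.tf_class.member_method))
```

A diff of the current API against these parsed protos is exactly what the compatibility test that follows automates.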
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py new file mode 100644 index 00000000000..1ffa8fc26c0 --- /dev/null +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -0,0 +1,242 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +"""TensorFlow API compatibility tests. + +This test ensures all changes to the public API of TensorFlow are intended. + +If this test fails, it means a change has been made to the public API. Backwards +incompatible changes are not allowed. You can run the test with +"--update_goldens" flag set to "True" to update goldens when making changes to +the public TF python API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import re +import sys +import unittest + +import tensorflow as tf + +from google.protobuf import text_format + +from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.tools.api.lib import api_objects_pb2 +from tensorflow.tools.api.lib import python_object_to_proto_visitor +from tensorflow.tools.common import public_api +from tensorflow.tools.common import traverse + +# FLAGS defined at the bottom: +FLAGS = None +# DEFINE_boolean, update_goldens, default False: +_UPDATE_GOLDENS_HELP = """ + Update stored golden files if API is updated. WARNING: All API changes + have to be authorized by TensorFlow leads. +""" + +# DEFINE_boolean, verbose_diffs, default False: +_VERBOSE_DIFFS_HELP = """ + If set to true, print line by line diffs on all libraries. If set to + false, only print which libraries have differences. 
+""" + +_API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' +_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' +_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' + + +def _KeyToFilePath(key): + """From a given key, construct a filepath.""" + def _ReplaceCapsWithDash(matchobj): + match = matchobj.group(0) + return '-%s' % (match.lower()) + + case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key) + return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % case_insensitive_key) + + +def _FileNameToKey(filename): + """From a given filename, construct a key we use for api objects.""" + def _ReplaceDashWithCaps(matchobj): + match = matchobj.group(0) + return match[1].upper() + + base_filename = os.path.basename(filename) + base_filename_without_ext = os.path.splitext(base_filename)[0] + api_object_key = re.sub( + '((-[a-z]){1})', _ReplaceDashWithCaps, base_filename_without_ext) + return api_object_key + + +class ApiCompatibilityTest(test.TestCase): + + def __init__(self, *args, **kwargs): + super(ApiCompatibilityTest, self).__init__(*args, **kwargs) + + golden_update_warning_filename = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _UPDATE_WARNING_FILE) + self._update_golden_warning = file_io.read_file_to_string( + golden_update_warning_filename) + + test_readme_filename = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _TEST_README_FILE) + self._test_readme_message = file_io.read_file_to_string( + test_readme_filename) + + def _AssertProtoDictEquals(self, + expected_dict, + actual_dict, + verbose=False, + update_goldens=False): + """Diff given dicts of protobufs and report differences a readable way. + + Args: + expected_dict: a dict of TFAPIObject protos constructed from golden + files. + actual_dict: a ict of TFAPIObject protos constructed by reading from the + TF package linked to the test. + verbose: Whether to log the full diffs, or simply report which files were + different. + update_goldens: Whether to update goldens when there are diffs found. + """ + diffs = [] + verbose_diffs = [] + + expected_keys = set(expected_dict.keys()) + actual_keys = set(actual_dict.keys()) + only_in_expected = expected_keys - actual_keys + only_in_actual = actual_keys - expected_keys + all_keys = expected_keys | actual_keys + + # This will be populated below. + updated_keys = [] + + for key in all_keys: + diff_message = '' + verbose_diff_message = '' + # First check if the key is not found in one or the other. + if key in only_in_expected: + diff_message = 'Object %s expected but not found (removed).' % key + verbose_diff_message = diff_message + elif key in only_in_actual: + diff_message = 'New object %s found (added).' % key + verbose_diff_message = diff_message + else: + # Now we can run an actual proto diff. + try: + self.assertProtoEquals(expected_dict[key], actual_dict[key]) + except AssertionError as e: + updated_keys.append(key) + diff_message = 'Change detected in python object: %s.' % key + verbose_diff_message = str(e) + + # All difference cases covered above. If any difference found, add to the + # list. + if diff_message: + diffs.append(diff_message) + verbose_diffs.append(verbose_diff_message) + + # If diffs are found, handle them based on flags. 
+ if diffs: + diff_count = len(diffs) + logging.error(self._test_readme_message) + logging.error('%d differences found between API and golden.', diff_count) + messages = verbose_diffs if verbose else diffs + for i in range(diff_count): + logging.error('Issue %d\t: %s', i + 1, messages[i]) + + if update_goldens: + # Write files if requested. + logging.warning(self._update_golden_warning) + + # If the keys are only in expected, some objects are deleted. + # Remove files. + for key in only_in_expected: + filepath = _KeyToFilePath(key) + file_io.delete_file(filepath) + + # If the files are only in actual (current library), these are new + # modules. Write them to files. Also record all updates in files. + for key in only_in_actual | set(updated_keys): + filepath = _KeyToFilePath(key) + file_io.write_string_to_file( + filepath, text_format.MessageToString(actual_dict[key])) + else: + # Fail if we cannot fix the test by updating goldens. + self.fail('%d differences found between API and golden.' % diff_count) + + else: + logging.info('No differences found between API and golden.') + + @unittest.skipUnless( + sys.version_info.major == 2 and os.uname()[0] == 'Linux', + 'API compatibility test goldens are generated using python2 on Linux.') + def testAPIBackwardsCompatibility(self): + # Extract all API information. + visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() + + public_api_visitor = public_api.PublicAPIVisitor(visitor) + public_api_visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf, public_api_visitor) + + proto_dict = visitor.GetProtos() + + # Read all golden files. + expression = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*')) + golden_file_list = file_io.get_matching_files(expression) + + def _ReadFileToProto(filename): + """Read a filename, create a protobuf from its contents.""" + ret_val = api_objects_pb2.TFAPIObject() + text_format.Merge(file_io.read_file_to_string(filename), ret_val) + return ret_val + + golden_proto_dict = { + _FileNameToKey(filename): _ReadFileToProto(filename) + for filename in golden_file_list + } + + # Diff them. If the test is run to update goldens, only report diffs + # but do not fail. + self._AssertProtoDictEquals( + golden_proto_dict, + proto_dict, + verbose=FLAGS.verbose_diffs, + update_goldens=FLAGS.update_goldens) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP) + parser.add_argument( + '--verbose_diffs', type=bool, default=False, help=_VERBOSE_DIFFS_HELP) + FLAGS, unparsed = parser.parse_known_args() + + # Now update argv, so that unittest library does not get confused.
+ sys.argv = [sys.argv[0]] + unparsed + test.main() diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD index 3b1901fd567..a2ffca97ecb 100644 --- a/tensorflow/tools/benchmark/BUILD +++ b/tensorflow/tools/benchmark/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_lite", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core:test", diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md index 5cb1aa6cf85..fd1bebe8359 100644 --- a/tensorflow/tools/benchmark/README.md +++ b/tensorflow/tools/benchmark/README.md @@ -9,6 +9,8 @@ both on desktop machines and on Android. ### On Android: +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. + (1) build for your specific platform, e.g.: ```bash $bazel build -c opt \ diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 180600e3b45..dfad11adf0b 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/command_line_flags.h" @@ -109,8 +111,21 @@ void CreateTensorsFromInputInfo( InitializeTensor(input.initialization_values, &input_tensor); break; } + case DT_BOOL: { + InitializeTensor(input.initialization_values, &input_tensor); + break; + } + case DT_STRING: { + if (!input.initialization_values.empty()) { + LOG(FATAL) << "Initialization values are not supported for strings"; + } + auto type_tensor = input_tensor.flat(); + type_tensor = type_tensor.constant(""); + break; + } default: - LOG(FATAL) << "Unsupported input type: " << input.data_type; + LOG(FATAL) << "Unsupported input type: " + << DataTypeString(input.data_type); } input_tensors->push_back({input.name, input_tensor}); } @@ -195,7 +210,7 @@ Status CalculateFlops(const GraphDef& graph, Status RunBenchmark(const std::vector& inputs, const std::vector& outputs, Session* session, - StatSummarizer* stats) { + StatSummarizer* stats, int64* inference_time_us) { std::vector > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); @@ -204,21 +219,27 @@ Status RunBenchmark(const std::vector& inputs, tensorflow::Status s; RunOptions run_options; - run_options.set_trace_level(RunOptions::FULL_TRACE); - RunMetadata run_metadata; + if (stats != nullptr) { + run_options.set_trace_level(RunOptions::FULL_TRACE); + } + RunMetadata run_metadata; + const int64 start_time = Env::Default()->NowMicros(); s = session->Run(run_options, input_tensors, outputs, {}, &output_tensors, &run_metadata); + const int64 end_time = Env::Default()->NowMicros(); + *inference_time_us = end_time - start_time; if (!s.ok()) { LOG(ERROR) << "Error during inference: " << s; + return 
s; } - assert(run_metadata.has_step_stats()); - - const StepStats& step_stats = run_metadata.step_stats(); - - stats->ProcessStepStats(step_stats); + if (stats != nullptr) { + assert(run_metadata.has_step_stats()); + const StepStats& step_stats = run_metadata.step_stats(); + stats->ProcessStepStats(step_stats); + } return s; } @@ -226,15 +247,24 @@ Status RunBenchmark(const std::vector& inputs, Status TimeMultipleRuns(double sleep_seconds, int num_runs, const std::vector& inputs, const std::vector& outputs, Session* session, - StatSummarizer* stats) { + StatSummarizer* stats, int64* total_time_us) { // Convert the run_delay string into a timespec. timespec req; req.tv_sec = static_cast(sleep_seconds); req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - LOG(INFO) << "Running benchmark"; + *total_time_us = 0; + + LOG(INFO) << "Running benchmark for " << num_runs << " iterations " + << (stats != nullptr ? "with" : "without") + << " detailed stat logging:"; + + Stat stat; for (int i = 0; i < num_runs; ++i) { - Status run_status = RunBenchmark(inputs, outputs, session, stats); + int64 time; + Status run_status = RunBenchmark(inputs, outputs, session, stats, &time); + stat.UpdateStat(time); + *total_time_us += time; if (!run_status.ok()) { LOG(INFO) << "Failed on run " << i; return run_status; @@ -244,9 +274,16 @@ Status TimeMultipleRuns(double sleep_seconds, int num_runs, // This can be helpful to determine the effect of mobile processor // scaling and thermal throttling. if (sleep_seconds > 0.0) { +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else nanosleep(&req, nullptr); +#endif } } + std::stringstream stream; + stat.OutputToStream(&stream); + LOG(INFO) << stream.str() << std::endl; return Status::OK(); } @@ -273,6 +310,7 @@ int Main(int argc, char** argv) { bool show_type = true; bool show_summary = true; bool show_flops = false; + int warmup_runs = 2; std::vector flag_list = { Flag("graph", &graph, "graph file name"), @@ -297,10 +335,11 @@ int Main(int argc, char** argv) { Flag("show_memory", &show_memory, "whether to list stats by memory used"), Flag("memory_limit", &memory_limit, "how many items to show by memory used"), - Flag("show_type", &show_time, "whether to list stats by op type"), - Flag("show_summary", &show_time, + Flag("show_type", &show_type, "whether to list stats by op type"), + Flag("show_summary", &show_summary, "whether to show a summary of the stats"), Flag("show_flops", &show_flops, "whether to estimate the model's FLOPs"), + Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), }; string usage = Flags::Usage(argv[0], flag_list); const bool parse_result = Flags::Parse(&argc, argv, flag_list); @@ -351,6 +390,7 @@ int Main(int argc, char** argv) { LOG(INFO) << "Benchmark name: [" << benchmark_name << "]"; LOG(INFO) << "Output prefix: [" << output_prefix << "]"; LOG(INFO) << "Show sizes: [" << show_sizes << "]"; + LOG(INFO) << "Warmup runs: [" << warmup_runs << "]"; std::unique_ptr session; std::unique_ptr stats; @@ -383,6 +423,12 @@ int Main(int argc, char** argv) { CHECK(str_util::SplitAndParseAsInts(input_layer_shapes[n], ',', &sizes)) << "Incorrect size string specified: " << input_layer_shapes[n]; for (int i = 0; i < sizes.size(); ++i) { + int32 size = sizes[i]; + if (size == -1) { + LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return -1; + } input.shape.AddDim(sizes[i]); } input.name = input_layers[n]; @@ -395,18 +441,48 @@ int Main(int argc, char** 
argv) { inputs.push_back(input); } - const int64 start_time = Env::Default()->NowMicros(); - Status time_status = - TimeMultipleRuns(sleep_seconds, num_runs, inputs, output_layers, - session.get(), stats.get()); - const int64 end_time = Env::Default()->NowMicros(); - const double wall_time = (end_time - start_time) / 1000000.0; + // If requested, run through the graph first to preinitialize everything + // before the benchmarking runs. + int64 warmup_time_us = 0; + if (warmup_runs > 0) { + Status warmup_time_status = + TimeMultipleRuns(sleep_seconds, warmup_runs, inputs, output_layers, + session.get(), nullptr, &warmup_time_us); + if (!warmup_time_status.ok()) { + LOG(ERROR) << "Timing failed with " << warmup_time_status; + return -1; + } + } - if (!time_status.ok()) { - LOG(ERROR) << "Timing failed with " << time_status; + // Capture overall inference time without stat logging overhead. This is the + // timing data that can be compared to other libraries. + int64 no_stat_time_us = 0; + Status no_stat_time_status = + TimeMultipleRuns(sleep_seconds, num_runs, inputs, output_layers, + session.get(), nullptr, &no_stat_time_us); + const double no_stat_wall_time = no_stat_time_us / 1000000.0; + if (!no_stat_time_status.ok()) { + LOG(ERROR) << "Timing failed with " << no_stat_time_status; return -1; } + // Run again to gather detailed log stats to get a better idea of where + // relative time is going within the graph. + int64 stat_time_us = 0; + Status stat_time_status = + TimeMultipleRuns(sleep_seconds, num_runs, inputs, output_layers, + session.get(), stats.get(), &stat_time_us); + if (!stat_time_status.ok()) { + LOG(ERROR) << "Timing failed with " << stat_time_status; + return -1; + } + + LOG(INFO) << "Average inference timings in us: " + << "Warmup: " + << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " + << "no stats: " << no_stat_time_us / num_runs << ", " + << "with stats: " << stat_time_us / num_runs; + stats->PrintStepStats(); if (show_sizes) { @@ -437,7 +513,7 @@ int Main(int argc, char** argv) { pretty_flops = strings::StrCat(rounded_flops, " billion FLOPs"); } LOG(INFO) << "FLOPs estimate: " << strings::HumanReadableNum(total_flops); - const double mean_run_time = wall_time / num_runs; + const double mean_run_time = no_stat_wall_time / num_runs; LOG(INFO) << "FLOPs/second: " << strings::HumanReadableNum( static_cast(total_flops / mean_run_time)); @@ -448,15 +524,38 @@ int Main(int argc, char** argv) { int64 total_size = inputs[0].shape.num_elements(); // Throughput in MB/s - const double throughput = DataTypeSize(inputs[0].data_type) * total_size * - num_runs / static_cast(wall_time) / - (1024 * 1024); + const double throughput = + DataTypeSize(inputs[0].data_type) * total_size * num_runs / + static_cast(no_stat_wall_time) / (1024 * 1024); // Report the stats.
TestReporter reporter(output_prefix, benchmark_name); - reporter.Initialize(); - reporter.Benchmark(num_runs, -1.0, wall_time, throughput); - reporter.Close(); + TF_QCHECK_OK(reporter.Initialize()); + TF_QCHECK_OK( + reporter.Benchmark(num_runs, -1.0, no_stat_wall_time, throughput)); + TF_QCHECK_OK(reporter.Close()); + + std::map node_type_map_count; + std::map node_type_map_time; + std::map node_type_map_memory; + std::map node_type_map_times_called; + + int64 accumulated_us; + stats->ComputeStatsByType(&node_type_map_count, &node_type_map_time, + &node_type_map_memory, + &node_type_map_times_called, &accumulated_us); + for (const auto& time : node_type_map_time) { + std::stringstream stream; + stream << benchmark_name << "_" << time.first; + TestReporter node_reporter(output_prefix, stream.str()); + + LOG(INFO) << "Outputting: [" << time.first << "]"; + + TF_QCHECK_OK(node_reporter.Initialize()); + TF_QCHECK_OK(node_reporter.Benchmark( + num_runs, -1.0, (time.second * num_runs) / 1000000.0f, -1.0)); + TF_QCHECK_OK(node_reporter.Close()); + } } return 0; diff --git a/tensorflow/tools/benchmark/benchmark_model.h b/tensorflow/tools/benchmark/benchmark_model.h index d2757e94fa6..b9c0a488a4b 100644 --- a/tensorflow/tools/benchmark/benchmark_model.h +++ b/tensorflow/tools/benchmark/benchmark_model.h @@ -38,13 +38,13 @@ Status InitializeSession(int num_threads, const string& graph, // Does a single run of the model that's been loaded into the given session. Status RunBenchmark(const std::vector& inputs, const std::vector& outputs, Session* session, - StatSummarizer* stats); + StatSummarizer* stats, int64* inference_time_us); // Runs the model multiple times, keeping track of timing information. Status TimeMultipleRuns(double sleep_seconds, int num_runs, const std::vector& inputs, const std::vector& outputs, Session* session, - StatSummarizer* stats); + StatSummarizer* stats, int64* total_time_us); // Handles all setup and argument parsing. int Main(int argc, char** argv); diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc index 9e0a3bd9400..ee7f24c0cf7 100644 --- a/tensorflow/tools/benchmark/benchmark_model_test.cc +++ b/tensorflow/tools/benchmark/benchmark_model_test.cc @@ -61,8 +61,9 @@ TEST(BenchmarkModelTest, InitializeAndRun) { &loaded_graph_def)); std::unique_ptr stats; stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get()))); + int64 time; TF_ASSERT_OK(benchmark_model::TimeMultipleRuns( - 0.0, 10, {input}, {output_name}, session.get(), stats.get())); + 0.0, 10, {input}, {output_name}, session.get(), stats.get(), &time)); } } // namespace diff --git a/tensorflow/tools/ci_build/Dockerfile.android b/tensorflow/tools/ci_build/Dockerfile.android index 1d888e4eaed..c6679f78826 100644 --- a/tensorflow/tools/ci_build/Dockerfile.android +++ b/tensorflow/tools/ci_build/Dockerfile.android @@ -6,21 +6,16 @@ MAINTAINER Jan Prach COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:mc3man/trusty-media && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh RUN /install/install_bazel.sh -# Set up bazelrc. -COPY install/.bazelrc /root/.bazelrc -ENV BAZELRC /root/.bazelrc +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc # Install extra libraries for android sdk.
-# (see http://stackoverflow.com/questions/33427893/can-not-run-android-sdk-build-tools-23-0-2-aapt) RUN apt-get update && apt-get install -y \ python-numpy \ - lib32stdc++6 \ - lib32z1 \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -31,28 +26,28 @@ ENV ANDROID_DEV_HOME /android RUN mkdir -p ${ANDROID_DEV_HOME} # Install Android SDK. -ENV ANDROID_SDK_FILENAME android-sdk_r24.4.1-linux.tgz -ENV ANDROID_SDK_URL http://dl.google.com/android/${ANDROID_SDK_FILENAME} +ENV ANDROID_SDK_FILENAME tools_r25.2.5-linux.zip +ENV ANDROID_SDK_URL https://dl.google.com/android/repository/${ANDROID_SDK_FILENAME} ENV ANDROID_API_LEVEL 23 -ENV ANDROID_BUILD_TOOLS_VERSION 23.0.2 +# Build Tools Version liable to change. +ENV ANDROID_BUILD_TOOLS_VERSION 25.0.2 ENV ANDROID_SDK_HOME ${ANDROID_DEV_HOME}/sdk ENV PATH ${PATH}:${ANDROID_SDK_HOME}/tools:${ANDROID_SDK_HOME}/platform-tools RUN cd ${ANDROID_DEV_HOME} && \ wget -q ${ANDROID_SDK_URL} && \ - tar -xzf ${ANDROID_SDK_FILENAME} && \ + unzip ${ANDROID_SDK_FILENAME} -d android-sdk-linux && \ rm ${ANDROID_SDK_FILENAME} && \ bash -c "ln -s ${ANDROID_DEV_HOME}/android-sdk-* ${ANDROID_SDK_HOME}" && \ echo y | android update sdk --no-ui -a --filter tools,platform-tools,android-${ANDROID_API_LEVEL},build-tools-${ANDROID_BUILD_TOOLS_VERSION} # Install Android NDK. -ENV ANDROID_NDK_FILENAME android-ndk-r10e-linux-x86_64.bin -ENV ANDROID_NDK_URL http://dl.google.com/android/ndk/${ANDROID_NDK_FILENAME} +ENV ANDROID_NDK_FILENAME android-ndk-r12b-linux-x86_64.zip +ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME} ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk ENV PATH ${PATH}:${ANDROID_NDK_HOME} RUN cd ${ANDROID_DEV_HOME} && \ wget -q ${ANDROID_NDK_URL} && \ - chmod +x ${ANDROID_NDK_FILENAME} && \ - ./${ANDROID_NDK_FILENAME} -o${ANDROID_DEV_HOME} && \ + unzip ${ANDROID_NDK_FILENAME} -d ${ANDROID_DEV_HOME} && \ rm ${ANDROID_NDK_FILENAME} && \ bash -c "ln -s ${ANDROID_DEV_HOME}/android-ndk-* ${ANDROID_NDK_HOME}" diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake index 8a28fe6cdf9..9013dc012d9 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cmake +++ b/tensorflow/tools/ci_build/Dockerfile.cmake @@ -1,3 +1,17 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== FROM ubuntu:16.04 MAINTAINER Shanqing Cai @@ -7,9 +21,10 @@ COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh RUN /install/install_deb_packages.sh +RUN apt-get update +RUN apt-get install -y --no-install-recommends python-pip RUN pip install --upgrade numpy # Install golang RUN add-apt-repository -y ppa:ubuntu-lxc/lxd-stable -RUN apt-get update RUN apt-get install -y golang diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu b/tensorflow/tools/ci_build/Dockerfile.cpu index 86ead3dd4df..206108930a1 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cpu +++ b/tensorflow/tools/ci_build/Dockerfile.cpu @@ -6,7 +6,6 @@ MAINTAINER Jan Prach COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:mc3man/trusty-media && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh RUN /install/install_pip_packages.sh @@ -14,7 +13,7 @@ RUN /install/install_bazel.sh RUN /install/install_proto3.sh RUN /install/install_buildifier.sh RUN /install/install_auditwheel.sh +RUN /install/install_golang.sh -# Set up bazelrc. -COPY install/.bazelrc /root/.bazelrc -ENV BAZELRC /root/.bazelrc +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.debian.jessie.cpu b/tensorflow/tools/ci_build/Dockerfile.debian.jessie.cpu index fa74320b1e5..b914f51918c 100644 --- a/tensorflow/tools/ci_build/Dockerfile.debian.jessie.cpu +++ b/tensorflow/tools/ci_build/Dockerfile.debian.jessie.cpu @@ -5,14 +5,22 @@ MAINTAINER Jan Prach # Copy and run the install scripts. COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh -RUN echo "deb http://http.debian.net/debian jessie-backports main" | tee -a /etc/apt/sources.list +RUN echo "deb http://http.debian.net/debian jessie-backports main" | \ + tee -a /etc/apt/sources.list +# Workaround bug in Jessie backport repository deb packages +# http://serverfault.com/questions/830636/cannot-install-openjdk-8-jre-headless-on-debian-jessie +RUN apt-get update && \ + apt-get install -y --no-install-recommends -t jessie-backports \ + openjdk-8-jre-headless ca-certificates-java && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* RUN /install/install_deb_packages.sh RUN /install/install_pip_packages.sh RUN /install/install_bazel.sh +RUN /install/install_golang.sh # Fix a virtualenv install issue specific to Debian Jessie. RUN pip install --upgrade virtualenv -# Set up bazelrc. -COPY install/.bazelrc /root/.bazelrc -ENV BAZELRC /root/.bazelrc +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu index 4d7f6ef95da..5d18295f68d 100644 --- a/tensorflow/tools/ci_build/Dockerfile.gpu +++ b/tensorflow/tools/ci_build/Dockerfile.gpu @@ -1,20 +1,25 @@ -FROM nvidia/cuda:8.0-cudnn5-devel +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04 MAINTAINER Jan Prach +# In the Ubuntu 14.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include +RUN cp -P /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64 + # Copy and run the install scripts. 
COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:mc3man/trusty-media && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh RUN /install/install_pip_packages.sh RUN /install/install_bazel.sh +RUN /install/install_golang.sh -# Set up bazelrc. -COPY install/.bazelrc /root/.bazelrc -ENV BAZELRC /root/.bazelrc +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH # Configure the build for our CUDA configuration. diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu_clang b/tensorflow/tools/ci_build/Dockerfile.gpu_clang new file mode 100644 index 00000000000..c4342d17f5f --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.gpu_clang @@ -0,0 +1,36 @@ +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu14.04 + +MAINTAINER Ilya Biryukov + +# In the Ubuntu 14.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp /usr/include/cudnn.h /usr/local/cuda/include +RUN cp /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64 + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa + +# LLVM requires cmake version 3.4.3, but ppa:george-edison55/cmake-3.x only +# provides version 3.2.2. +# So we skip it in `install_deb_packages.sh`, and later install it from +# https://cmake.org in `install_cmake_for_clang.sh`. +RUN /install/install_deb_packages.sh --without_cmake +RUN /install/install_pip_packages.sh +RUN /install/install_bazel.sh +RUN /install/install_golang.sh + +# Install cmake and build clang +RUN /install/install_cmake_for_clang.sh +RUN /install/build_and_install_clang.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH + +# Configure the build for our CUDA configuration. +ENV TF_NEED_CUDA 1 +ENV TF_CUDA_CLANG 1 +ENV CLANG_CUDA_COMPILER_PATH /usr/local/bin/clang +ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0 diff --git a/tensorflow/tools/ci_build/Dockerfile.hadoop b/tensorflow/tools/ci_build/Dockerfile.hadoop index 8a97a4b466c..489493c26e4 100644 --- a/tensorflow/tools/ci_build/Dockerfile.hadoop +++ b/tensorflow/tools/ci_build/Dockerfile.hadoop @@ -6,7 +6,6 @@ MAINTAINER Jonathan Hseu COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:mc3man/trusty-media && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh RUN /install/install_pip_packages.sh @@ -15,6 +14,5 @@ RUN /install/install_proto3.sh RUN /install/install_buildifier.sh RUN /install/install_hadoop.sh -# Set up bazelrc. -COPY install/.bazelrc /root/.bazelrc -ENV BAZELRC /root/.bazelrc +# Set up the master bazelrc configuration file. 
+COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.tensorboard b/tensorflow/tools/ci_build/Dockerfile.tensorboard index 0ce2ab3aa54..9795872e2c4 100644 --- a/tensorflow/tools/ci_build/Dockerfile.tensorboard +++ b/tensorflow/tools/ci_build/Dockerfile.tensorboard @@ -6,7 +6,6 @@ MAINTAINER Jan Prach COPY install/*.sh /install/ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:mc3man/trusty-media && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh RUN /install/install_tensorboard_packages.sh diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md index 5c90fceaf70..ad83669950f 100644 --- a/tensorflow/tools/ci_build/README.md +++ b/tensorflow/tools/ci_build/README.md @@ -13,28 +13,32 @@ run continuous integration [ci.tensorflow.org](https://ci.tensorflow.org). You can run all the jobs **without docker** if you are on mac or on linux and you just don't want docker. Just install all the dependencies from - [os_setup.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/get_started/os_setup.md). + [Installing TensorFlow](https://www.tensorflow.org/install/). Then run any of the one liners below without the `tensorflow/tools/ci_build/ci_build.sh` in them. 2. Clone tensorflow repository. ```bash -git clone https://github.com/tensorflow/tensorflow.git -``` + git clone https://github.com/tensorflow/tensorflow.git + ``` 3. Go to tensorflow directory ```bash -cd tensorflow -``` + cd tensorflow + ``` 4. Build what you want, for example ```bash -tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... -``` - + tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... + ``` + If you are using the Docker image on Windows or OS X, the Docker VM's default + memory limit may be too low to build TensorFlow. This can result in + strange-looking errors, e.g. the compilation may fail with `gcc: internal + compiler error: Killed (program cc1plus)`. Try increasing the memory limit in + the Docker preferences. ## Jobs @@ -53,10 +57,10 @@ tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... tensorflow/tools/ci_build/ci_build.sh GPU bazel test -c opt --config=cuda //tensorflow/... # build pip with gpu support -tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/pip.sh GPU +tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/pip.sh GPU -c opt --config=cuda # build and run gpu tests using python 3 -CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3" tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/pip.sh GPU +CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3" tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/pip.sh GPU -c opt --config=cuda # build android example app tensorflow/tools/ci_build/ci_build.sh ANDROID tensorflow/tools/ci_build/builds/android.sh diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh index 3282efa8d2d..63250e0a4da 100755 --- a/tensorflow/tools/ci_build/builds/android_full.sh +++ b/tensorflow/tools/ci_build/builds/android_full.sh @@ -31,8 +31,10 @@ configure_android_workspace CPUS=armeabi-v7a,arm64-v8a,x86,x86_64 OUT_DIR="$(pwd)/out/" +AAR_LIB_TMP="$(pwd)/aar_libs" rm -rf ${OUT_DIR} +rm -rf ${AAR_LIB_TMP} # Build all relevant native libraries for each architecture. 
for CPU in ${CPUS//,/ } @@ -50,6 +52,9 @@ do copy_lib bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so copy_lib bazel-bin/tensorflow/examples/android/libtensorflow_demo.so copy_lib bazel-bin/tensorflow/tools/benchmark/benchmark_model + + mkdir -p ${AAR_LIB_TMP}/jni/${CPU} + cp bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so ${AAR_LIB_TMP}/jni/${CPU} done # Build Jar and also demo containing native libs for all architectures. @@ -60,15 +65,41 @@ echo "========== Building TensorFlow Android Jar and Demo ==========" bazel --bazelrc=/dev/null build -c opt --fat_apk_cpu=${CPUS} \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/contrib/android:android_tensorflow_inference_java \ + //tensorflow/contrib/android:android_tensorflow_inference_java.aar \ //tensorflow/examples/android:tensorflow_demo -echo "Copying demo and Jar to ${OUT_DIR}" +echo "Copying demo, AAR and Jar to ${OUT_DIR}" cp bazel-bin/tensorflow/examples/android/tensorflow_demo.apk \ bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar ${OUT_DIR} -echo "========== Makefile Build Test ==========" +cp bazel-bin/tensorflow/contrib/android/android_tensorflow_inference_java.aar \ + ${OUT_DIR}/tensorflow.aar + +# TODO(andrewharp): build native libs into AAR directly once +# https://github.com/bazelbuild/bazel/issues/348 is resolved. +echo "Adding native libs to AAR" +chmod +w ${OUT_DIR}/tensorflow.aar +pushd ${AAR_LIB_TMP} +zip -ur ${OUT_DIR}/tensorflow.aar $(find jni -name "*.so") +popd +rm -rf ${AAR_LIB_TMP} + # Test Makefile build just to make sure it still works. if [ -z "$NDK_ROOT" ]; then export NDK_ROOT=${ANDROID_NDK_HOME} fi + +echo "========== Benchmark Makefile Build Test ==========" tensorflow/contrib/makefile/build_all_android.sh + +echo "========== Demo Makefile Build Test ==========" +tensorflow/contrib/makefile/build_all_android.sh \ +-s $(pwd)/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in \ +-t "libtensorflow_inference.so libtensorflow_demo.so" + +# Test Makefile build for tensorflow runtime with hexagon. +# -b ... build only, -p ... use prebuilt binaries +# This uses prebuilt binaries for hexagon dependencies because building +# hexagon binaries from source code requires the Qualcomm SDK. 
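+# (Illustrative: the combined "-bp" flags below therefore mean build-only plus
+# prebuilt, i.e. compile the hexagon-enabled runtime against the prebuilt
+# dependencies without actually running the inception demo.)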
+echo "========== Hexagon Build Test ==========" +tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh -bp diff --git a/tensorflow/tools/ci_build/builds/builds_common.sh b/tensorflow/tools/ci_build/builds/builds_common.sh index d9a6ce96a58..fd9a14bd698 100644 --- a/tensorflow/tools/ci_build/builds/builds_common.sh +++ b/tensorflow/tools/ci_build/builds/builds_common.sh @@ -230,7 +230,7 @@ android_sdk_repository( android_ndk_repository( name="androidndk", path="${ANDROID_NDK_HOME}", - api_level=21) + api_level=14) EOF fi fi diff --git a/tensorflow/tools/ci_build/builds/configured b/tensorflow/tools/ci_build/builds/configured index f813d6c13f5..25cb51ea7cc 100755 --- a/tensorflow/tools/ci_build/builds/configured +++ b/tensorflow/tools/ci_build/builds/configured @@ -47,6 +47,10 @@ export CI_BUILD_PYTHON="${CI_BUILD_PYTHON:-python}" export PYTHON_BIN_PATH="${PYTHON_BIN_PATH:-$(which ${CI_BUILD_PYTHON})}" if [ "${CONTAINER_TYPE}" == "gpu" ]; then export TF_NEED_CUDA=1 +elif [ "${CONTAINER_TYPE}" == "gpu_clang" ]; then + export TF_NEED_CUDA=1 + export TF_CUDA_CLANG=1 + export CLANG_CUDA_COMPILER_PATH="/usr/local/bin/clang" else export TF_NEED_CUDA=0 fi diff --git a/tensorflow/tools/ci_build/builds/docker_test.sh b/tensorflow/tools/ci_build/builds/docker_test.sh index ee004eb46c2..e337ea4b059 100755 --- a/tensorflow/tools/ci_build/builds/docker_test.sh +++ b/tensorflow/tools/ci_build/builds/docker_test.sh @@ -114,7 +114,7 @@ fi docker run -v ${BASE_DIR}:/tensorflow-src -w /tensorflow-src \ ${GPU_EXTRA_PARAMS} \ "${DOCKER_IMG_TAG}" \ -/bin/bash -c "tensorflow/tools/ci_build/builds/test_installation.sh && "\ +/bin/bash -c "tensorflow/tools/ci_build/builds/run_pip_tests.sh && "\ "tensorflow/tools/ci_build/builds/test_tutorials.sh && "\ "tensorflow/tools/ci_bukld/builds/integration_tests.sh" diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh index ce0c8549573..5052d3626c9 100755 --- a/tensorflow/tools/ci_build/builds/libtensorflow.sh +++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh @@ -14,19 +14,27 @@ # limitations under the License. # ============================================================================== # -# Script to generate a tarball containing the TensorFlow C-library which -# consists of the C API header file and libtensorflow.so. +# Script to generate tarballs: +# (1) The TensorFlow C-library: Containing C API header files and libtensorflow.so +# (2) Native library for the TensorFlow Java API: Containing libtensorflow_jni.so +# And jars: +# (3) Java API .jar +# (4) Java API sources .jar # -# Work in progress but this is a step towards a "binary" distribution of the -# TensorFlow C-library allowing TensorFlow language bindings to be used -# without having to recompile the TensorFlow framework from sources, which -# takes a while and also introduces many other dependencies. +# These binary distributions will allow use of TensorFlow in various languages +# without having to compile the TensorFlow framework from sources, which takes +# a while and also introduces many other dependencies. 
# # Usage: # - Source this file in another bash script # - Execute build_libtensorflow_tarball SUFFIX # -# Produces: lib_package/libtensorflow${SUFFIX}.tar.gz +# Produces: +# - lib_package/libtensorflow${SUFFIX}.tar.gz +# - lib_package/libtensorflow_jni${SUFFIX}.tar.gz +# - lib_package/libtensorflow.jar +# - lib_package/libtensorflow-src.jar +# - lib_package/libtensorflow_proto.zip # # ASSUMPTIONS: # - build_libtensorflow_tarball is invoked from the root of the git tree. @@ -38,6 +46,10 @@ function build_libtensorflow_tarball() { echo "Must run this from the root of the bazel workspace" exit 1 fi + # Delete any leftovers from previous builds in this workspace. + DIR=lib_package + rm -rf ${DIR} + TARBALL_SUFFIX="${1}" BAZEL="bazel --bazelrc ./tensorflow/tools/ci_build/install/.bazelrc" BAZEL_OPTS="-c opt" @@ -52,11 +64,23 @@ function build_libtensorflow_tarball() { # and https://github.com/bazelbuild/bazel/issues/1580 # have been resolved and the "manual" tags on the BUILD targets # in tensorflow/tools/lib_package/BUILD are removed. - # Till then, must manually run the test. - bazel test ${BAZEL_OPTS} //tensorflow/tools/lib_package:libtensorflow_test + # Until then, these tests must be run manually since they are + # not covered by continuous integration. + bazel test ${BAZEL_OPTS} \ + //tensorflow/tools/lib_package:libtensorflow_test \ + //tensorflow/tools/lib_package:libtensorflow_java_test + + bazel build ${BAZEL_OPTS} \ + //tensorflow/tools/lib_package:libtensorflow.tar.gz \ + //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz \ + //tensorflow/java:libtensorflow.jar \ + //tensorflow/java:libtensorflow-src.jar \ + //tensorflow/tools/lib_package:libtensorflow_proto.zip - bazel build ${BAZEL_OPTS} //tensorflow/tools/lib_package:libtensorflow.tar.gz - DIR=lib_package mkdir -p ${DIR} cp bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz ${DIR}/libtensorflow${TARBALL_SUFFIX}.tar.gz + cp bazel-bin/tensorflow/tools/lib_package/libtensorflow_jni.tar.gz ${DIR}/libtensorflow_jni${TARBALL_SUFFIX}.tar.gz + cp bazel-bin/tensorflow/java/libtensorflow.jar bazel-bin/tensorflow/java/libtensorflow-src.jar ${DIR} + cp bazel-genfiles/tensorflow/tools/lib_package/libtensorflow_proto.zip ${DIR} + chmod -x ${DIR}/* } diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh index 1e55ad01245..85c712d3c6d 100755 --- a/tensorflow/tools/ci_build/builds/pip.sh +++ b/tensorflow/tools/ci_build/builds/pip.sh @@ -19,8 +19,7 @@ # The PIP installation is done using the --user flag. # # Usage: -# pip.sh CONTAINER_TYPE [--mavx] [--mavx2] -# [--test_tutorials] [--integration_tests] +# pip.sh CONTAINER_TYPE [--test_tutorials] [--integration_tests] [bazel flags] # # When executing the Python unit tests, the script obeys the shell # variables: TF_BUILD_BAZEL_CLEAN, TF_BUILD_INSTALL_EXTRA_PIP_PACKAGES, # @@ -30,7 +29,7 @@ # script to perform bazel clean prior to main build and test steps. # # TF_BUILD_INSTALL_EXTRA_PIP_PACKAGES overrides the default extra pip packages -# to be installed in virtualenv before test_installation.sh is called. Multiple +# to be installed in virtualenv before run_pip_tests.sh is called. Multiple # package names are separated with spaces. # # If NO_TEST_ON_INSTALL has any non-empty and non-0 value, the test-on-install # @@ -39,8 +38,10 @@ # If NO_TEST_USER_OPS has any non-empty and non-0 value, the testing of user- # defined ops against the installation will be skipped. 
# -# Use --mavx or --mavx2 to let bazel use --copt=-mavx or --copt=-mavx2 options -# while building the pip package, respectively. +# If NO_TEST_TFDBG_BINARIES has any non-empty and non-0 value, the testing of +# TensorFlow Debugger (tfdbg) binaries and examples will be skipped. +# +# Any flags not listed in the usage above will be passed directly to Bazel. # # If the --test_tutorials flag is set, it will cause the script to run the # tutorial tests (see test_tutorials.sh) after the PIP @@ -49,6 +50,11 @@ # to run. # + +# Helper function: Strip leading and trailing whitespace +str_strip () { + echo -e "$1" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' +} + # Fixed naming patterns for wheel (.whl) files given different python versions if [[ $(uname) == "Linux" ]]; then declare -A WHL_TAGS @@ -66,32 +72,38 @@ source "${SCRIPT_DIR}/builds_common.sh" # Get the command line arguments CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' ) +shift -if [[ ! -z "${TF_BUILD_BAZEL_CLEAN}" ]] && \ +if [[ -n "${TF_BUILD_BAZEL_CLEAN}" ]] && \ [[ "${TF_BUILD_BAZEL_CLEAN}" != "0" ]]; then echo "TF_BUILD_BAZEL_CLEAN=${TF_BUILD_BAZEL_CLEAN}: Performing 'bazel clean'" bazel clean fi DO_TEST_USER_OPS=1 -if [[ ! -z "${NO_TEST_USER_OPS}" ]] && \ +if [[ -n "${NO_TEST_USER_OPS}" ]] && \ [[ "${NO_TEST_USER_OPS}" != "0" ]]; then echo "NO_TEST_USER_OPS=${NO_TEST_USER_OPS}: Will skip testing of user ops" DO_TEST_USER_OPS=0 fi +DO_TEST_TFDBG_BINARIES=1 +if [[ -n "${NO_TEST_TFDBG_BINARIES}" ]] && \ + [[ "${NO_TEST_TFDBG_BINARIES}" != "0" ]]; then + echo "NO_TEST_TFDBG_BINARIES=${NO_TEST_TFDBG_BINARIES}: Will skip testing of tfdbg binaries" + DO_TEST_TFDBG_BINARIES=0 +fi + DO_TEST_TUTORIALS=0 DO_INTEGRATION_TESTS=0 -MAVX_FLAG="" +BAZEL_FLAGS="" while true; do if [[ "${1}" == "--test_tutorials" ]]; then DO_TEST_TUTORIALS=1 elif [[ "${1}" == "--integration_tests" ]]; then DO_INTEGRATION_TESTS=1 - elif [[ "${1}" == "--mavx" ]]; then - MAVX_FLAG="--copt=-mavx" - elif [[ "${1}" == "--mavx2" ]]; then - MAVX_FLAG="--copt=-mavx2" + else + BAZEL_FLAGS="${BAZEL_FLAGS} ${1}" fi shift @@ -100,18 +112,18 @@ while true; do fi done -if [[ ! -z "${MAVX_FLAG}" ]]; then - echo "Using MAVX flag: ${MAVX_FLAG}" -fi +BAZEL_FLAGS=$(str_strip "${BAZEL_FLAGS}") + +echo "Using Bazel flags: ${BAZEL_FLAGS}" PIP_BUILD_TARGET="//tensorflow/tools/pip_package:build_pip_package" GPU_FLAG="" if [[ ${CONTAINER_TYPE} == "cpu" ]] || \ [[ ${CONTAINER_TYPE} == "debian.jessie.cpu" ]]; then - bazel build -c opt ${MAVX_FLAG} ${PIP_BUILD_TARGET} || \ + bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \ die "Build failed." elif [[ ${CONTAINER_TYPE} == "gpu" ]]; then - bazel build -c opt --config=cuda ${MAVX_FLAG} ${PIP_BUILD_TARGET} || \ + bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \ die "Build failed." GPU_FLAG="--gpu" else @@ -125,7 +137,7 @@ fi # If still in a virtualenv, deactivate it first -if [[ ! -z "$(which deactivate)" ]]; then +if [[ -n "$(which deactivate)" ]]; then echo "It appears that we are already in a virtualenv. Deactivating..." deactivate || die "FAILED: Unable to deactivate from existing virtualenv" fi @@ -163,6 +175,11 @@ if [[ $(echo ${WHL_PATH} | wc -w) -ne 1 ]]; then "directory: ${PIP_WHL_DIR}" fi +# Print the size of the PIP wheel file. +echo +echo "Size of the PIP wheel file built (bytes): $(ls -l ${WHL_PATH} | awk '{print $5}')" +echo + # Rename the whl file properly so it will have the python # version tags and platform tags that won't cause pip install issues. 
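# (Illustrative, with a hypothetical version number: on Darwin a wheel such as
#   tensorflow-1.1.0-cp35-cp35m-macosx_10_11_x86_64.whl
# is copied below to the generic name tensorflow-1.1.0-py3-none-any.whl, while
# on Linux the linux platform tag is later rewritten to manylinux1 before
# auditing.)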
if [[ $(uname) == "Linux" ]]; then @@ -174,6 +191,8 @@ elif [[ $(uname) == "Darwin" ]]; then PY_TAGS="py2-none" elif [[ ${PY_MAJOR_MINOR_VER} == "3.5" ]]; then PY_TAGS="py3-none" + elif [[ ${PY_MAJOR_MINOR_VER} == "3.6" ]]; then + PY_TAGS="py3-none" fi PLATFORM_TAG="any" fi @@ -181,19 +200,22 @@ fi WHL_DIR=$(dirname "${WHL_PATH}") WHL_BASE_NAME=$(basename "${WHL_PATH}") -if [[ ! -z "${PY_TAGS}" ]]; then +if [[ -n "${PY_TAGS}" ]]; then NEW_WHL_BASE_NAME=$(echo ${WHL_BASE_NAME} | cut -d \- -f 1)-\ $(echo ${WHL_BASE_NAME} | cut -d \- -f 2)-${PY_TAGS}-${PLATFORM_TAG}.whl if [[ ! -f "${WHL_DIR}/${NEW_WHL_BASE_NAME}" ]]; then - cp "${WHL_DIR}/${WHL_BASE_NAME}" "${WHL_DIR}/${NEW_WHL_BASE_NAME}" && \ - echo "Copied wheel file: ${WHL_BASE_NAME} --> ${NEW_WHL_BASE_NAME}" || \ + if cp "${WHL_DIR}/${WHL_BASE_NAME}" "${WHL_DIR}/${NEW_WHL_BASE_NAME}" + then + echo "Copied wheel file: ${WHL_BASE_NAME} --> ${NEW_WHL_BASE_NAME}" + else die "ERROR: Failed to copy wheel file to ${NEW_WHL_BASE_NAME}" + fi fi fi if [[ $(uname) == "Linux" ]]; then - AUDITED_WHL_NAME="${WHL_DIR}/$(echo ${WHL_BASE_NAME} | sed "s/linux/manylinux1/")" + AUDITED_WHL_NAME="${WHL_DIR}/$(echo ${WHL_BASE_NAME//linux/manylinux1})" # Repair the wheels for cpu manylinux1 if [[ ${CONTAINER_TYPE} == "cpu" ]]; then @@ -221,14 +243,20 @@ echo "Installing pip whl file: ${WHL_PATH}" VENV_DIR="${PIP_TEST_ROOT}/venv" if [[ -d "${VENV_DIR}" ]]; then - rm -rf "${VENV_DIR}" && \ - echo "Removed existing virtualenv directory: ${VENV_DIR}" || \ - die "Failed to remove existing virtualenv directory: ${VENV_DIR}" + if rm -rf "${VENV_DIR}" + then + echo "Removed existing virtualenv directory: ${VENV_DIR}" + else + die "Failed to remove existing virtualenv directory: ${VENV_DIR}" + fi fi -mkdir -p ${VENV_DIR} && \ - echo "Created virtualenv directory: ${VENV_DIR}" || \ - die "FAILED to create virtualenv directory: ${VENV_DIR}" +if mkdir -p ${VENV_DIR} +then + echo "Created virtualenv directory: ${VENV_DIR}" +else + die "FAILED to create virtualenv directory: ${VENV_DIR}" +fi # Verify that virtualenv exists if [[ -z $(which virtualenv) ]]; then @@ -250,7 +278,7 @@ pip install --upgrade pip==8.1.2 # Force tensorflow reinstallation. Otherwise it may not get installed from # last build if it had the same version number as previous build. -PIP_FLAGS="--upgrade --force-reinstall --no-deps" +PIP_FLAGS="--upgrade --force-reinstall" pip install -v ${PIP_FLAGS} ${WHL_PATH} || \ die "pip install (forcing to reinstall tensorflow) FAILED" echo "Successfully installed pip package ${WHL_PATH}" @@ -263,13 +291,13 @@ for PACKAGE in ${INSTALL_EXTRA_PIP_PACKAGES}; do die "pip install ${PACKAGE} FAILED" done -if [[ ! -z "${NO_TEST_ON_INSTALL}" ]] && +if [[ -n "${NO_TEST_ON_INSTALL}" ]] && [[ "${NO_TEST_ON_INSTALL}" != "0" ]]; then echo "NO_TEST_ON_INSTALL=${NO_TEST_ON_INSTALL}:" echo " Skipping ALL Python unit tests on install" else - # Call test_installation.sh to perform test-on-install - "${SCRIPT_DIR}/test_installation.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG} || + # Call run_pip_tests.sh to perform test-on-install + "${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG} || die "PIP tests-on-install FAILED" fi @@ -279,6 +307,24 @@ if [[ "${DO_TEST_USER_OPS}" == "1" ]]; then die "PIP user-op tests-on-install FAILED" fi +# Test TensorFlow Debugger (tfdbg) examples. 
+if [[ "${DO_TEST_TFDBG_BINARIES}" == "1" ]]; then + echo + echo "Testing TensorFlow Debugger (tfdbg) binaries" + echo + + # cd to a temporary directory to avoid picking up Python files in the source + # tree. + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + "${SCRIPT_DIR}/../../../python/debug/examples/examples_test.sh" \ + --virtualenv || \ + die "PIP tests-on-install of tfdbg binaries FAILED" + + popd +fi + # Optional: Run the tutorial tests if [[ "${DO_TEST_TUTORIALS}" == "1" ]]; then "${SCRIPT_DIR}/test_tutorials.sh" --virtualenv || \ diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh new file mode 100755 index 00000000000..8e364f7ffb7 --- /dev/null +++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +# +# Run the python unit tests from the source code on the pip installation. +# +# Usage: +# run_pip_tests.sh [--virtualenv] [--gpu] [--mac] +# +# If the flag --virtualenv is set, the script will use "python" as the Python +# binary path. Otherwise, it will use tools/python_bin_path.sh to determine +# the Python binary path. +# +# The --gpu flag informs the script that this is a GPU build, so that the +# appropriate test blacklists can be applied accordingly. +# +# The --mac flag informs the script that this is running on mac. Mac does not +# have flock, so we should skip using parallel_gpu_execute on mac. +# +# TF_BUILD_APPEND_ARGUMENTS: +# Additional command line arguments for the bazel, +# pip.sh or android.sh command + +# Script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/builds_common.sh" + +# Process input arguments +IS_VIRTUALENV=0 +IS_GPU=0 +IS_MAC=0 +while true; do + if [[ "$1" == "--virtualenv" ]]; then + IS_VIRTUALENV=1 + elif [[ "$1" == "--gpu" ]]; then + IS_GPU=1 + elif [[ "$1" == "--mac" ]]; then + IS_MAC=1 + fi + shift + + if [[ -z "$1" ]]; then + break + fi +done + +TF_GPU_COUNT=${TF_GPU_COUNT:-8} + +# PIP tests should have a "different" path. Different than the one we place +# virtualenv, because we are deleting and recreating it here. +PIP_TEST_PREFIX=bazel_pip +PIP_TEST_ROOT=$(pwd)/${PIP_TEST_PREFIX} +rm -rf $PIP_TEST_ROOT +mkdir -p $PIP_TEST_ROOT +ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow + +# Do not run tests with "no_pip" tag. If running GPU tests, also do not run +# tests with no_pip_gpu tag. +PIP_TEST_FILTER_TAG="-no_pip" +if [[ ${IS_GPU} == "1" ]]; then + PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}" +fi + +# Bazel flags we need for all tests: +# define=no_tensorflow_py_deps=true, to skip all test dependencies. +# test_lang_filters=py only py tests for pip package testing +# TF_BUILD_APPEND_ARGUMENTS any user supplied args. 
+BAZEL_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py \ + --build_tests_only -k --test_tag_filters=${PIP_TEST_FILTER_TAG} \ + --test_timeout 300,450,1200,3600 ${TF_BUILD_APPEND_ARGUMENTS}" + +BAZEL_TEST_TARGETS="//${PIP_TEST_PREFIX}/tensorflow/contrib/... \ + //${PIP_TEST_PREFIX}/tensorflow/python/... \ + -//${PIP_TEST_PREFIX}/tensorflow/contrib/tensorboard/..." + +# Clean the bazel cache +bazel clean + +# Run configure again, we might be using a different python path, due to +# virtualenv. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_ENABLE_XLA=${TF_BUILD_ENABLE_XLA:-0} + +# Obtain the path to Python binary +if [[ ${IS_VIRTUALENV} == "1" ]]; then + PYTHON_BIN_PATH="$(which python)" +else + source tools/python_bin_path.sh + # Assume: PYTHON_BIN_PATH is exported by the script above +fi + +export TF_NEED_CUDA=$IS_GPU +yes "" | ./configure + +# Figure out how many concurrent tests we can run and do run the tests. +BAZEL_PARALLEL_TEST_FLAGS="" +if [[ $IS_GPU == 1 ]]; then + # Number of test threads is the number of GPU cards available. + if [[ $IS_MAC == 1 ]]; then + BAZEL_PARALLEL_TEST_FLAGS="--local_test_jobs=1" + else + PAR_TEST_JOBS=$TF_GPU_COUNT + BAZEL_PARALLEL_TEST_FLAGS="--local_test_jobs=${TF_GPU_COUNT} \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute" + fi +else + # Number of test threads is the number of physical CPUs. + if [[ $IS_MAC == 1 ]]; then + BAZEL_PARALLEL_TEST_FLAGS="--local_test_jobs=$(sysctl -n hw.ncpu)" + else + BAZEL_PARALLEL_TEST_FLAGS="--local_test_jobs=$(grep -c ^processor /proc/cpuinfo)" + fi +fi + +# Actually run the tests. +bazel test ${BAZEL_FLAGS} ${BAZEL_PARALLEL_TEST_FLAGS} -- \ + ${BAZEL_TEST_TARGETS} diff --git a/tensorflow/tools/ci_build/builds/test_installation.sh b/tensorflow/tools/ci_build/builds/test_installation.sh deleted file mode 100755 index eb64fbcf185..00000000000 --- a/tensorflow/tools/ci_build/builds/test_installation.sh +++ /dev/null @@ -1,603 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# -# Build the Python PIP installation package for TensorFlow -# and run the Python unit tests from the source code on the installation -# -# Usage: -# test_installation.sh [--virtualenv] [--gpu] [--mac] -# -# If the flag --virtualenv is set, the script will use "python" as the Python -# binary path. Otherwise, it will use tools/python_bin_path.sh to determine -# the Python binary path. -# -# The --gpu flag informs the script that this is a GPU build, so that the -# appropriate test blacklists can be applied accordingly. -# -# The --mac flag informs the script that this is running on mac. Mac does not -# have flock, so we should skip using parallel_gpu_execute on mac. 
-# -# When executing the Python unit tests, the script obeys the shell -# variables: PY_TEST_WHITELIST, PY_TEST_BLACKLIST, PY_TEST_GPU_BLACKLIST, -# -# To select only a subset of the Python tests to run, set the environment -# variable PY_TEST_WHITELIST, e.g., -# PY_TEST_WHITELIST="tensorflow/python/kernel_tests/shape_ops_test.py" -# Separate the tests with a colon (:). Leave this environment variable empty -# to disable the whitelist. -# -# You can also ignore a set of the tests by using the environment variable -# PY_TEST_BLACKLIST. For example, you can include in PY_TEST_BLACKLIST the -# tests that depend on Python modules in TensorFlow source that are not -# exported publicly. -# -# In addition, you can put blacklist for only GPU build inthe environment -# variable PY_TEST_GPU_BLACKLIST. -# -# TF_BUILD_BAZEL_CLEAN, if set to any non-empty and non-0 value, directs the -# script to perform bazel clean prior to main build and test steps. -# -# TF_GPU_COUNT, Set the number of GPUs in the system. We run only this many -# concurrent tests when running GPU tests. -# -# TF_BUILD_EXTRA_EXCLUSIVE_INSTALL_TESTS, add to the default list of -# Python unit tests to run in exclusive mode (i.e., not concurrently with -# other tests), separated with colons -# -# TF_BUILD_FILTER_INSTALL_TESTS_BY_TAG: If set to a non-empty string -# (e.g., "local"), will filter the Python install-tests by that string as -# bazel tags. Multiple filter tags can be used. Both the inclusive filtering -# mode and the exclusive filtering mode can be used. For example: -# -# TF_BUILD_FILTER_INSTALL_TESTS_BY_TAG="local,-manual" -# -# will let the script run the Python unit tests that have the tag "local" -# and do not have the tag "manual". The "-" marks the exclusive filtering -# mode. The inclusive mode is the default. Use commas to separate the tags. -# -# If the environmental variable NO_TEST_ON_INSTALL is set to any non-empty -# value, the script will exit after the pip install step. - -# ============================================================================= -# Test blacklist: General -# -# tensorflow/python/framework/ops_test.py -# depends on depends on "test_ops", which is defined in a C++ file wrapped as -# a .py file through the Bazel rule “tf_gen_ops_wrapper_py”. -# tensorflow/util/protobuf/compare_test.py: -# depends on compare_test_pb2 defined outside Python -# tensorflow/python/framework/device_test.py: -# depends on CheckValid() and ToString(), both defined externally -# tensorflow/python/framework/file_system_test.py: -# depends on having the .so which is not shipped in the pip package. -# tensorflow/contrib/quantization/*: -# These depend on an .so mechanism that's not shipped in the pip package. -# tensorflow/python/platform/default/*_test.py: -# These are obsolete and replaced by corresponding files in python/platform. -# They will be removed in the future. 
- -PY_TEST_BLACKLIST="${PY_TEST_BLACKLIST}:"\ -"tensorflow/python/framework/ops_test.py:"\ -"tensorflow/python/util/protobuf/compare_test.py:"\ -"tensorflow/python/framework/device_test.py:"\ -"tensorflow/python/framework/file_system_test.py:"\ -"tensorflow/contrib/quantization/python/dequantize_op_test.py:"\ -"tensorflow/contrib/quantization/python/quantized_conv_ops_test.py:"\ -"tensorflow/contrib/quantization/tools/quantize_graph_test.py:"\ -"tensorflow/contrib/session_bundle/bundle_shim_test.py:"\ -"tensorflow/contrib/session_bundle/exporter_test.py:"\ -"tensorflow/contrib/session_bundle/session_bundle_test.py:"\ -"tensorflow/python/platform/default/_resource_loader_test.py:"\ -"tensorflow/python/platform/default/flags_test.py:"\ -"tensorflow/python/platform/default/logging_test.py:"\ -"tensorflow/python/saved_model/saved_model_test.py:"\ -"tensorflow/contrib/learn/nonlinear_test.py:"\ -"tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py:"\ -"tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py:" - -# Test blacklist: GPU-only -PY_TEST_GPU_BLACKLIST="${PY_TEST_GPU_BLACKLIST}:"\ -"tensorflow/python/client/session_test.py:"\ -"tensorflow/python/framework/function_test.py:"\ -"tensorflow/contrib/integrate/python/ops/odes_test.py:"\ -"tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" - -# Tests that should be run in the exclusive mode (i.e., not parallel with -# other tests) -PY_TEST_EXCLUSIVE_LIST="" - -# Append custom list of exclusive tests -if [[ ! -z "${TF_BUILD_EXTRA_EXCLUSIVE_INSTALL_TESTS}" ]]; then - PY_TEST_EXCLUSIVE_LIST="${PY_TEST_EXCLUSIVE_LIST}:"\ -"${TF_BUILD_EXTRA_EXCLUSIVE_INSTALL_TESTS}" -fi - -# ============================================================================= - -echo "PY_TEST_WHITELIST: ${PY_TEST_WHITELIST}" -echo "PY_TEST_BLACKLIST: ${PY_TEST_BLACKLIST}" -echo "PY_TEST_GPU_BLACKLIST: ${PY_TEST_GPU_BLACKLIST}" - - -# Script directory -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -source "${SCRIPT_DIR}/builds_common.sh" - -TF_GPU_COUNT=${TF_GPU_COUNT:-8} - -# Process input arguments -IS_VIRTUALENV=0 -IS_GPU=0 -IS_MAC=0 -while true; do - if [[ "$1" == "--virtualenv" ]]; then - IS_VIRTUALENV=1 - elif [[ "$1" == "--gpu" ]]; then - IS_GPU=1 - elif [[ "$1" == "--mac" ]]; then - IS_MAC=1 - fi - shift - - if [[ -z "$1" ]]; then - break - fi -done - -# Obtain the path to Python binary -if [[ ${IS_VIRTUALENV} == "1" ]]; then - PYTHON_BIN_PATH="$(which python)" -else - source tools/python_bin_path.sh - # Assume: PYTHON_BIN_PATH is exported by the script above -fi - -# Obtain the path to head/ghead binary (for log file printing) -HEAD_BIN="ghead" -if [[ -z $(which "${HEAD_BIN}") ]]; then - # This is not Mac (which uses coreutils/ghead), use head. - HEAD_BIN="head" - if [[ -z $(which "${HEAD_BIN}") ]]; then - die "Unable to obtain path to head or ghead" - fi -fi - -if [[ -z "${PYTHON_BIN_PATH}" ]]; then - die "PYTHON_BIN_PATH was not provided. If this is not virtualenv, "\ -"did you run configure?" -fi - -# Append GPU-only test blacklist -if [[ ${IS_GPU} == "1" ]]; then - PY_TEST_BLACKLIST="${PY_TEST_BLACKLIST}:${PY_TEST_GPU_BLACKLIST}" -fi - -# Determine the major and minor versions of Python being used (e.g., 2.7) -# This info will be useful for determining the directory of the local pip -# installation of Python -PY_MAJOR_MINOR_VER=$(${PYTHON_BIN_PATH} -V 2>&1 | awk '{print $NF}' | cut -d. 
-f-2) - -echo "Python binary path to be used in PIP install-test: ${PYTHON_BIN_PATH} "\ -"(Major.Minor version: ${PY_MAJOR_MINOR_VER})" - -# Avoid permission issues outside container -umask 000 - -# Directory from which the unit-test files will be run -PY_TEST_DIR_REL="pip_test/tests" -PY_TEST_DIR=$(realpath ${PY_TEST_DIR_REL}) # Get absolute path -rm -rf ${PY_TEST_DIR} && mkdir -p ${PY_TEST_DIR} - -# Create test log directory -PY_TEST_LOG_DIR_REL=${PY_TEST_DIR_REL}/logs -PY_TEST_LOG_DIR=$(realpath ${PY_TEST_LOG_DIR_REL}) # Absolute path - -mkdir ${PY_TEST_LOG_DIR} - -# Copy source files that are required by the tests but are not included in the -# PIP package - -# Look for local Python library directory -# pushd/popd avoids importing TensorFlow from the source directory. -pushd /tmp > /dev/null -TF_INSTALL_PATH=$(dirname \ - $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.__file__)")) -popd > /dev/null - -if [[ -z ${TF_INSTALL_PATH} ]]; then - die "Failed to find path where TensorFlow is installed." -else - echo "Found TensorFlow install path: ${TF_INSTALL_PATH}" -fi - -echo "Copying some source directories required by Python unit tests but "\ -"not included in install to TensorFlow install path: ${TF_INSTALL_PATH}" - -# Files for tensorflow.python.tools -rm -rf ${TF_INSTALL_PATH}/python/tools -cp -r tensorflow/python/tools \ - ${TF_INSTALL_PATH}/python/tools -touch ${TF_INSTALL_PATH}/python/tools/__init__.py # Make module visible - -# Files for tensorflow.examples -rm -rf ${TF_INSTALL_PATH}/examples/image_retraining -mkdir -p ${TF_INSTALL_PATH}/examples/image_retraining -cp -r tensorflow/examples/image_retraining/retrain.py \ - ${TF_INSTALL_PATH}/examples/image_retraining/retrain.py -touch ${TF_INSTALL_PATH}/examples/__init__.py -touch ${TF_INSTALL_PATH}/examples/image_retraining/__init__.py - -echo "Copying additional files required by tests to working directory "\ -"for test: ${PY_TEST_DIR}" - -# Image files required by some tests, e.g., images_ops_test.py - -mkdir -p ${PY_TEST_DIR}/tensorflow/core/lib -rm -rf ${PY_TEST_DIR}/tensorflow/core/lib/jpeg -cp -r tensorflow/core/lib/jpeg ${PY_TEST_DIR}/tensorflow/core/lib -rm -rf ${PY_TEST_DIR}/tensorflow/core/lib/png -cp -r tensorflow/core/lib/png ${PY_TEST_DIR}/tensorflow/core/lib -rm -rf ${PY_TEST_DIR}/tensorflow/core/lib/gif -cp -r tensorflow/core/lib/gif ${PY_TEST_DIR}/tensorflow/core/lib - -# Copy test data from tensorflow/contrib/ffmpeg - -mkdir -p ${PY_TEST_DIR}/tensorflow/contrib/ffmpeg -rm -rf ${PY_TEST_DIR}/tensorflow/contrib/ffmpeg/testdata -cp -r tensorflow/contrib/ffmpeg/testdata ${PY_TEST_DIR} - -# Run tests -DIR0=$(pwd) -ALL_PY_TESTS_0=$(find tensorflow/{contrib,examples,python,tensorboard} \ - -type f \( -name "*_test.py" -o -name "test_*.py" \) | sort) - - -# Subroutine for filtering test file names by a bazel tag. -filter_tests_by_bazel_tag() { - # Usage: filter_tests_by_bazel_tag (--inclusive | --exclusive) - # - # - # E.g., filter_tests_by_bazel_tag --inclusive "local" - # "dir1/test1.py dir2/test2.py" - # - # Use the flag --inclusive so that only the tests that have the tag will be - # included in the returned string. - # Use the flag --exclusive so that the returned string will consist of only - # the tests that do not have the tag. - # INPUT_TESTS are the name of the input Python unit test files, seperated by - # spaces. 
- # - # The output string (through stdout) is: OUTPUT_TESTS | DISCARDED_TESTS - # That is: a list of tests that passed the filter, followed by " | ", - # followed by a list of tests that are discarded - - FILTER_MODE=$1 - TAG=$2 - INPUT_TESTS=$3 - - # Input sanity checks - if [[ "${FILTER_MODE}" != "--inclusive" ]] && - [[ "${FILTER_MODE}" != "--exclusive" ]]; then - echo "ERROR: Unrecognized filter mode: ${FILTER_MODE}" - exit 1 - fi - if [[ -z "${TAG}" ]]; then - echo "ERROR: Bazal tag is not supplied" - exit 1 - fi - if [[ -z "${INPUT_TESTS}" ]]; then - echo "ERROR: INPUT_TESTS is not supplied" - exit 1 - fi - - # Check bazel on path - if [[ -z $(which bazel) ]]; then - echo "ERROR: bazel is not on path" - exit 1 - fi - - # Get all bazel targets that have the specified tag - BAZEL_TARGETS=\ -$(bazel query "kind(py_test, attr(tags, "${TAG}", //tensorflow/...))" | sort) - - TARGET_ALIASES=":" - for TARGET in ${BAZEL_TARGETS}; do - # Transform, e.g., //tensorflow/python/kernel_tests:xent_op_test --> - # python-xent_op_test - # to be compared with the transformed strings from the Python unit test - # file names. - TARGET_1=$(echo "${TARGET}" | sed "s/:/ /g") - TARGET_PATH_1=$(echo "${TARGET_1}" | sed "s/\/\// /g" | sed "s/\// /g" \ - | awk '{print $2}') - TARGET_BASE_NAME=$(echo "${TARGET_1}" | awk '{print $NF}') - TARGET_ALIAS="${TARGET_PATH_1}-${TARGET_BASE_NAME}" - - TARGET_ALIASES="${TARGET_ALIASES}${TARGET_ALIAS}:" - done - TARGET_ALIASES="${TARGET_ALIASES}:" - - # Filter the list of tests obtained from listing files with the bazel query - # results. - TESTS_PASSED_FILTER="" - TESTS_BLOCKED_BY_FILTER="" - for PY_TEST in ${INPUT_TESTS}; do - # Transform, e.g., tensorflow/python/kernel_tests/xent_op_test.py --> - # python-xent_op_test - PY_TEST_PATH_1=$(echo "${PY_TEST}" | sed "s/\// /g" | awk '{print $2}') - PY_TEST_BASE_NAME=$(echo "${PY_TEST}" | sed "s/\// /g" \ - | awk '{print $NF}' | sed "s/\.py//g") - PY_TEST_ALIAS="${PY_TEST_PATH_1}-${PY_TEST_BASE_NAME}" - - TO_INCLUDE=0 - if [[ "${TARGET_ALIASES}" == *"${PY_TEST_ALIAS}"* ]] && \ - [[ "${FILTER_MODE}" == "--inclusive" ]]; then - TO_INCLUDE=1 - elif [[ "${TARGET_ALIASES}" != *"${PY_TEST_ALIAS}"* ]] && \ - [[ "${FILTER_MODE}" == "--exclusive" ]]; then - TO_INCLUDE=1 - fi - - if [[ ${TO_INCLUDE} == 1 ]]; then - TESTS_PASSED_FILTER="${TESTS_PASSED_FILTER} ${PY_TEST}" - else - TESTS_BLOCKED_BY_FILTER="${TESTS_BLOCKED_BY_FILTER} ${PY_TEST}" - fi - done - - echo "${TESTS_PASSED_FILTER} | ${TESTS_BLOCKED_BY_FILTER}" -} - - -if [[ ${TF_BUILD_FILTER_INSTALL_TESTS_BY_TAG} != "" ]]; then - # Iteratively apply the filter tags - TAGS=(${TF_BUILD_FILTER_INSTALL_TESTS_BY_TAG//,/ }) - for TAG in ${TAGS[@]}; do - if [[ ${TAG} == "-"* ]]; then - MODE="--exclusive" - TAG_1=$(echo ${TAG} | sed 's/-//') - else - MODE="--inclusive" - TAG_1=${TAG} - fi - - FILTER_OUTPUT=$(filter_tests_by_bazel_tag ${MODE} \ - "${TAG_1}" "${ALL_PY_TESTS_0}") - ALL_PY_TESTS_0=$(echo "${FILTER_OUTPUT}" | cut -d \| -f 1) - DISCARDED_TESTS=$(echo "${FILTER_OUTPUT}" | cut -d \| -f 2) - N_DISCARDED=$(echo "${DISCARDED_TESTS}" | wc -w) - - echo "" - echo "Skipping ${N_DISCARDED} test(s) due to filter tag \"${TAG}\":" - echo "${DISCARDED_TESTS}" - echo "" - done -fi - -# Move the exclusive tests to the back of the list -EXCLUSIVE_LIST="$(echo "${PY_TEST_EXCLUSIVE_LIST}" | sed -e 's/:/ /g')" - -ALL_PY_TESTS="" -for TEST in ${ALL_PY_TESTS_0}; do - if [[ ! 
${PY_TEST_EXCLUSIVE_LIST} == *"${TEST}"* ]]; then - ALL_PY_TESTS="${ALL_PY_TESTS} ${TEST}" - fi -done - -# Number of parallel (non-exclusive) tests -N_PAR_TESTS=$(echo ${ALL_PY_TESTS} | wc -w) -echo "Number of non-exclusive tests: ${N_PAR_TESTS}" - -for TEST in ${EXCLUSIVE_LIST}; do - ALL_PY_TESTS="${ALL_PY_TESTS} ${TEST}" -done - -PY_TEST_COUNT=$(echo ${ALL_PY_TESTS} | wc -w) - -if [[ ${PY_TEST_COUNT} -eq 0 ]]; then - die "ERROR: Cannot find any tensorflow Python unit tests to run on install" -fi - -# Iterate through all the Python unit test files using the installation -TEST_COUNTER=0 -PASS_COUNTER=0 -FAIL_COUNTER=0 -SKIP_COUNTER=0 -FAILED_TESTS="" -FAILED_TEST_LOGS="" - -if [[ "${IS_GPU}" == "1" ]]; then - if [[ "${IS_MAC}" == "1" ]]; then - N_JOBS=1 - else - N_JOBS=$TF_GPU_COUNT - fi -else - N_JOBS=$(grep -c ^processor /proc/cpuinfo) - if [[ -z ${N_JOBS} ]]; then - # Try the Mac way of getting number of CPUs - N_JOBS=$(sysctl -n hw.ncpu) - fi - - # If still cannot determine the number of CPUs, pick 8. - if [[ -z ${N_JOBS} ]]; then - N_JOBS=8 - echo "Cannot determine the number of processors" - echo "Using default concurrent job counter ${N_JOBS}" - fi -fi - -echo "Running Python tests-on-install with ${N_JOBS} concurrent jobs..." - -ALL_PY_TESTS=(${ALL_PY_TESTS}) -while true; do - TEST_LOGS="" - TEST_INDICES="" - TEST_FILE_PATHS="" - TEST_BASENAMES="" - - ITER_COUNTER=0 - while true; do - # Break if the end is reached - if [[ "${TEST_COUNTER}" -ge "${PY_TEST_COUNT}" ]]; then - break; - fi - - # for TEST_FILE_PATH in ${ALL_PY_TESTS}; do - TEST_FILE_PATH=${ALL_PY_TESTS[TEST_COUNTER]} - - ((TEST_COUNTER++)) - ((ITER_COUNTER++)) - - # If PY_TEST_WHITELIST is not empty, only the white-listed tests will be run - if [[ ! -z ${PY_TEST_WHITELIST} ]] && \ - [[ ! ${PY_TEST_WHITELIST} == *"${TEST_FILE_PATH}"* ]]; then - ((SKIP_COUNTER++)) - echo "Non-whitelisted test SKIPPED: ${TEST_FILE_PATH}" - - continue - fi - - # If the test is in the black list, skip it - if [[ ${PY_TEST_BLACKLIST} == *"${TEST_FILE_PATH}"* ]]; then - ((SKIP_COUNTER++)) - echo "Blacklisted test SKIPPED: ${TEST_FILE_PATH}" - continue - fi - - TEST_INDICES="${TEST_INDICES} ${TEST_COUNTER}" - TEST_FILE_PATHS="${TEST_FILE_PATHS} ${TEST_FILE_PATH}" - - # Copy to a separate directory to guard against the possibility of picking - # up modules in the source directory - cp ${TEST_FILE_PATH} ${PY_TEST_DIR}/ - - TEST_BASENAME=$(basename "${TEST_FILE_PATH}") - TEST_BASENAMES="${TEST_BASENAMES} ${TEST_BASENAME}" - - # Relative path of the test log. Use long path in case there are duplicate - # file names in the Python tests - TEST_LOG_REL="${PY_TEST_LOG_DIR_REL}/${TEST_FILE_PATH}.log" - mkdir -p $(dirname ${TEST_LOG_REL}) # Create directory for log - - TEST_LOG=$(realpath ${TEST_LOG_REL}) # Absolute path - TEST_LOGS="${TEST_LOGS} ${TEST_LOG}" - - # Launch test asynchronously - if [[ "${IS_GPU}" == "1" ]] && [[ "${IS_MAC}" == "0" ]]; then - # Only use this script without mac. This uses flock, which is not - # available in MacOSX. 
- "${SCRIPT_DIR}/../gpu_build/parallel_gpu_execute.sh" \ - "${SCRIPT_DIR}/py_test_delegate.sh" \ - "${PYTHON_BIN_PATH}" "${PY_TEST_DIR}/${TEST_BASENAME}" "${TEST_LOG}" & - else - "${SCRIPT_DIR}/py_test_delegate.sh" \ - "${PYTHON_BIN_PATH}" "${PY_TEST_DIR}/${TEST_BASENAME}" "${TEST_LOG}" & - fi - - if [[ "${TEST_COUNTER}" -ge "${N_PAR_TESTS}" ]]; then - # Run in exclusive mode - if [[ "${TEST_COUNTER}" -gt "${N_PAR_TESTS}" ]]; then - echo "Run test exclusively: ${PY_TEST_DIR}/${TEST_BASENAME}" - fi - break - fi - - if [[ "${ITER_COUNTER}" -ge "${N_JOBS}" ]] || - [[ "${TEST_COUNTER}" -ge "${PY_TEST_COUNT}" ]]; then - break - fi - - done - - # Wait for all processes to complete - wait - - TEST_LOGS=(${TEST_LOGS}) - TEST_FILE_PATHS=(${TEST_FILE_PATHS}) - TEST_BASENAMES=(${TEST_BASENAMES}) - - K=0 - for TEST_INDEX in ${TEST_INDICES}; do - TEST_FILE_PATH=${TEST_FILE_PATHS[K]} - TEST_RESULT=$(tail -1 "${TEST_LOGS[K]}" | awk '{print $1}') - ELAPSED_TIME=$(tail -1 "${TEST_LOGS[K]}" | cut -d' ' -f2-) - - PROG_STR="(${TEST_INDEX} / ${PY_TEST_COUNT})" - # Check for pass or failure status of the test outtput and exit - if [[ ${TEST_RESULT} -eq 0 ]]; then - ((PASS_COUNTER++)) - - echo "${PROG_STR} Python test-on-install PASSED (${ELAPSED_TIME}): ${TEST_FILE_PATH}" - else - ((FAIL_COUNTER++)) - - FAILED_TESTS="${FAILED_TESTS} ${TEST_FILE_PATH}" - FAILED_TEST_LOGS="${FAILED_TEST_LOGS} ${TEST_LOGS[K]}" - - echo "${PROG_STR} Python test-on-install FAILED (${ELAPSED_TIME}): ${TEST_FILE_PATH}" - - echo " Log @: ${TEST_LOGS[K]}" - echo "============== BEGINS failure log content ==============" - "${HEAD_BIN}" --lines=-1 "${TEST_LOGS[K]}" - echo "============== ENDS failure log content ==============" - echo "" - fi - cd ${DIR0} - - # Clean up files for this test - rm -f ${TEST_BASENAMES[K]} - - ((K++)) - done - - # Stop if the end is reached - if [[ "${TEST_COUNTER}" -ge "${PY_TEST_COUNT}" ]]; then - break; - fi -done - -# Clean up files copied for Python unit tests: -rm -rf ${TF_INSTALL_PATH}/python/tools -rm -rf ${TF_INSTALL_PATH}/examples/image_retraining -rm -rf ${PY_TEST_DIR}/tensorflow/core/lib/jpeg -rm -rf ${PY_TEST_DIR}/tensorflow/core/lib/png -rm -rf ${PY_TEST_DIR}/testdata - -echo "" -echo "${PY_TEST_COUNT} Python test(s):" \ - "${PASS_COUNTER} passed;" \ - "${FAIL_COUNTER} failed; " \ - "${SKIP_COUNTER} skipped" -echo "Test logs directory: ${PY_TEST_LOG_DIR_REL}" - -if [[ ${FAIL_COUNTER} -eq 0 ]]; then - echo "" - echo "Python test-on-install SUCCEEDED" - - exit 0 -else - echo "FAILED test(s):" - FAILED_TEST_LOGS=($FAILED_TEST_LOGS) - FAIL_COUNTER=0 - for TEST_NAME in ${FAILED_TESTS}; do - echo " ${TEST_NAME} (Log @: ${FAILED_TEST_LOGS[${FAIL_COUNTER}]})" - ((FAIL_COUNTER++)) - done - - echo "" - echo "Python test-on-install FAILED" - exit 1 -fi diff --git a/tensorflow/tools/ci_build/builds/test_user_ops.sh b/tensorflow/tools/ci_build/builds/test_user_ops.sh index 216abbe8e67..3b7e2348ad1 100755 --- a/tensorflow/tools/ci_build/builds/test_user_ops.sh +++ b/tensorflow/tools/ci_build/builds/test_user_ops.sh @@ -123,7 +123,7 @@ if [[ ${IS_GPU} == "0" ]]; then EXPECTED_OUTPUT="[42, 0, 0]" # Locate the op kernel C++ file - OP_KERNEL_CC="${SCRIPT_DIR}/../../../g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc" + OP_KERNEL_CC="${SCRIPT_DIR}/user_ops/zero_out_op_kernel_1.cc" OP_KERNEL_CC=$(realpath "${OP_KERNEL_CC}") if [[ ! 
-f "${OP_KERNEL_CC}" ]]; then @@ -162,13 +162,13 @@ else "${NVCC_BIN}" --version echo "" - OP_KERNEL_CU="${SCRIPT_DIR}/../../../g3doc/how_tos/adding_an_op/cuda_op_kernel.cu.cc" + OP_KERNEL_CU="${SCRIPT_DIR}/user_ops/cuda_op_kernel.cu.cc" OP_KERNEL_CU=$(realpath "${OP_KERNEL_CU}") if [[ ! -f "${OP_KERNEL_CU}" ]]; then die "ERROR: Unable to find user-op kernel CUDA file at: ${OP_KERNEL_CU}" fi - OP_KERNEL_CC="${SCRIPT_DIR}/../../../g3doc/how_tos/adding_an_op/cuda_op_kernel.cc" + OP_KERNEL_CC="${SCRIPT_DIR}/user_ops/cuda_op_kernel.cc" OP_KERNEL_CC=$(realpath "${OP_KERNEL_CC}") if [[ ! -f "${OP_KERNEL_CC}" ]]; then die "ERROR: Unable to find user-op kernel C++ file at: ${OP_KERNEL_CC}" diff --git a/tensorflow/g3doc/how_tos/adding_an_op/cuda_op_kernel.cc b/tensorflow/tools/ci_build/builds/user_ops/cuda_op_kernel.cc similarity index 100% rename from tensorflow/g3doc/how_tos/adding_an_op/cuda_op_kernel.cc rename to tensorflow/tools/ci_build/builds/user_ops/cuda_op_kernel.cc diff --git a/tensorflow/tools/ci_build/builds/user_ops/cuda_op_kernel.cu.cc b/tensorflow/tools/ci_build/builds/user_ops/cuda_op_kernel.cu.cc new file mode 100644 index 00000000000..65b50bd3ae9 --- /dev/null +++ b/tensorflow/tools/ci_build/builds/user_ops/cuda_op_kernel.cu.cc @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +__global__ void AddOneKernel(const int* in, const int N, int* out) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + out[i] = in[i] + 1; + } +} + +void AddOneKernelLauncher(const int* in, const int N, int* out) { + AddOneKernel<<<32, 256>>>(in, N, out); +} + +#endif diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc b/tensorflow/tools/ci_build/builds/user_ops/zero_out_op_kernel_1.cc similarity index 100% rename from tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc rename to tensorflow/tools/ci_build/builds/user_ops/zero_out_op_kernel_1.cc diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh index 3697fd46a0e..9c1b75d0048 100755 --- a/tensorflow/tools/ci_build/ci_build.sh +++ b/tensorflow/tools/ci_build/ci_build.sh @@ -18,7 +18,7 @@ # # # CONTAINER_TYPE: Type of the docker container used the run the build: -# e.g., (cpu | gpu | android | tensorboard) +# e.g., (cpu | gpu | gpu_clang | android | tensorboard) # # DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build. # If this optional value is not supplied (via the @@ -26,7 +26,7 @@ # directory as this script will be used. 
# # COMMAND: Command to be executed in the docker container, e.g., -# tensorflow/tools/ci_build/builds/pip.sh gpu +# tensorflow/tools/ci_build/builds/pip.sh gpu -c opt --config=cuda SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/builds/builds_common.sh" @@ -80,11 +80,11 @@ fi # cmake (CPU) builds do not require configuration. if [[ "${CONTAINER_TYPE}" == "cmake" ]]; then - CI_COMMAND_PREFIX="" + CI_COMMAND_PREFIX=("") fi # Use nvidia-docker if the container is GPU. -if [[ "${CONTAINER_TYPE}" == "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" == "gpu" ]] || [[ "${CONTAINER_TYPE}" == "gpu_clang" ]]; then DOCKER_BINARY="nvidia-docker" else DOCKER_BINARY="docker" @@ -104,7 +104,7 @@ BUILD_TAG="${BUILD_TAG:-tf_ci}" # Add extra params for cuda devices and libraries for GPU container. # And clear them if we are not building for GPU. -if [ "${CONTAINER_TYPE}" != "gpu" ]; then +if [[ "${CONTAINER_TYPE}" != "gpu" ]] && [[ "${CONTAINER_TYPE}" != "gpu_clang" ]]; then GPU_EXTRA_PARAMS="" fi @@ -120,9 +120,9 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]') # Print arguments. echo "WORKSPACE: ${WORKSPACE}" -echo "CI_DOCKER_EXTRA_PARAMS: ${CI_DOCKER_EXTRA_PARAMS[@]}" -echo "COMMAND: ${COMMAND[@]}" -echo "CI_COMMAND_PREFIX: ${CI_COMMAND_PREFIX[@]}" +echo "CI_DOCKER_EXTRA_PARAMS: ${CI_DOCKER_EXTRA_PARAMS[*]}" +echo "COMMAND: ${COMMAND[*]}" +echo "CI_COMMAND_PREFIX: ${CI_COMMAND_PREFIX[*]}" echo "CONTAINER_TYPE: ${CONTAINER_TYPE}" echo "BUILD_TAG: ${BUILD_TAG}" echo " (docker container name will be ${DOCKER_IMG_NAME})" @@ -140,7 +140,7 @@ if [[ $? != "0" ]]; then fi # Run the command inside the container. -echo "Running '${COMMAND[@]}' inside ${DOCKER_IMG_NAME}..." +echo "Running '${COMMAND[*]}' inside ${DOCKER_IMG_NAME}..." mkdir -p ${WORKSPACE}/bazel-ci_build-cache # By default we cleanup - remove the container once it finishes running (--rm) # and share the PID namespace (--pid=host) so the process inside does not have diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 4fd1277d63b..1cf87d7c7c0 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -18,7 +18,7 @@ # ci_parameterized_build.sh # # The script obeys the following required environment variables: -# TF_BUILD_CONTAINER_TYPE: (CPU | GPU | ANDROID | ANDROID_FULL) +# TF_BUILD_CONTAINER_TYPE: (CPU | GPU | GPU_CLANG | ANDROID | ANDROID_FULL) # TF_BUILD_PYTHON_VERSION: (PYTHON2 | PYTHON3 | PYTHON3.5) # TF_BUILD_IS_PIP: (NO_PIP | PIP | BOTH) # @@ -84,10 +84,14 @@ # support for Google Cloud Platform (GCP), which is # enabled by default. # TF_BUILD_OPTIONS: -# (FASTBUILD | OPT | OPTDBG | MAVX | MAVX2) +# (FASTBUILD | OPT | OPTDBG | MAVX | MAVX2_FMA | MAVX_DBG | +# MAVX2_FMA_DBG) # Use the specified configurations when building. # When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX # options, as this will replace the two. +# TF_SKIP_CONTRIB_TESTS: +# If set to any non-empty and non-0 value, contrib tests will be +# skipped. # # This script can be used by Jenkins parameterized / matrix builds. 
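# For example (an illustrative combination of the values listed above, not a
# real Jenkins configuration), a single matrix cell might invoke:
#   TF_BUILD_CONTAINER_TYPE=GPU TF_BUILD_PYTHON_VERSION=PYTHON3 \
#   TF_BUILD_IS_PIP=PIP TF_BUILD_OPTIONS=MAVX2_FMA \
#   tensorflow/tools/ci_build/ci_parameterized_build.sh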
@@ -117,8 +121,7 @@ DOCKER_MAIN_CMD="${CI_BUILD_DIR}/ci_build.sh" NO_DOCKER_MAIN_CMD="${CI_BUILD_DIR}/builds/configured" # Additional option flags to apply when Docker is unavailable (e.g., on Mac) -NO_DOCKER_OPT_FLAG="--linkopt=-headerpad_max_install_names "\ -"--genrule_strategy=standalone" +NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone" DO_DOCKER=1 @@ -147,6 +150,10 @@ else EXTRA_PARAMS="${EXTRA_PARAMS} -e TF_BUILD_ENABLE_XLA=1" fi +if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then + BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..." +fi + TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" ########################################################## @@ -193,8 +200,8 @@ echo " TF_BUILD_ENABLE_XLA=${TF_BUILD_ENABLE_XLA}" function get_cuda_capability_version() { if [[ ! -z $(which deviceQuery) ]]; then # The first listed device is used - echo $(deviceQuery | grep "CUDA Capability .* version" | \ - head -1 | awk '{print $NF}') + deviceQuery | grep "CUDA Capability .* version" | \ + head -1 | awk '{print $NF}' fi } @@ -217,8 +224,13 @@ fi # Process container type if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then : -elif [[ ${CTYPE} == "gpu" ]]; then - OPT_FLAG="${OPT_FLAG} --config=cuda" +elif [[ ${CTYPE} == "gpu" ]] || [[ ${CTYPE} == "gpu_clang" ]]; then + if [[ ${CTYPE} == "gpu" ]]; then + OPT_FLAG="${OPT_FLAG} --config=cuda" + else # ${CTYPE} == "gpu_clang" + OPT_FLAG="${OPT_FLAG} --config=cuda_clang" + fi + # Attempt to determine CUDA capability version automatically and use it if # CUDA capability version is not specified by the environment variables. @@ -305,11 +317,14 @@ else MAVX) OPT_FLAG="${OPT_FLAG} -c opt --copt=-mavx" ;; - MAVXDBG) + MAVX_DBG) OPT_FLAG="${OPT_FLAG} -c opt --copt=-g --copt=-mavx" ;; - MAVX2) - OPT_FLAG="${OPT_FLAG} -c opt --copt=-mavx2" + MAVX2_FMA) + OPT_FLAG="${OPT_FLAG} -c opt --copt=-mavx2 --copt=-mfma" + ;; + MAVX2_FMA_DBG) + OPT_FLAG="${OPT_FLAG} -c opt --copt=-g --copt=-mavx2 --copt=-mfma" ;; esac fi @@ -318,21 +333,35 @@ fi OPT_FLAG=$(str_strip "${OPT_FLAG}") -# Filter out benchmark tests if this is not a benchmarks job +# 1) Filter out benchmark tests if this is not a benchmarks job; +# 2) Filter out tests with the "nomac" tag if the build is on Mac OS X. EXTRA_ARGS="" +IS_MAC=0 +if [[ "$(uname)" == "Darwin" ]]; then + IS_MAC=1 +fi if [[ "${TF_BUILD_APPEND_ARGUMENTS}" == *"--test_tag_filters="* ]]; then ITEMS=(${TF_BUILD_APPEND_ARGUMENTS}) for ITEM in "${ITEMS[@]}"; do - if [[ ${ITEM} == *"--test_tag_filters="* ]] && - [[ ${ITEM} != *"benchmark-test"* ]]; then - EXTRA_ARGS="${EXTRA_ARGS} ${ITEM},-benchmark-test" + if [[ ${ITEM} == *"--test_tag_filters="* ]]; then + NEW_ITEM="${ITEM}" + if [[ ${NEW_ITEM} != *"benchmark-test"* ]]; then + NEW_ITEM="${NEW_ITEM},-benchmark-test" + fi + if [[ ${IS_MAC} == "1" ]] && [[ ${NEW_ITEM} != *"nomac"* ]]; then + NEW_ITEM="${NEW_ITEM},-nomac" + fi + EXTRA_ARGS="${EXTRA_ARGS} ${NEW_ITEM}" else EXTRA_ARGS="${EXTRA_ARGS} ${ITEM}" fi done else EXTRA_ARGS="${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-benchmark-test" + if [[ ${IS_MAC} == "1" ]]; then + EXTRA_ARGS="${EXTRA_ARGS},-nomac" + fi fi # For any "tool" dependencies in genrules, Bazel will build them for host @@ -353,7 +382,7 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || # CPU only command, fully parallel. 
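# ("Fully parallel" means bazel's default local test parallelism is left in
# place here; by contrast, the gpu/gpu_clang branch below caps the job count
# with --local_test_jobs=${TF_GPU_COUNT}.)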
NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} ${EXTRA_ARGS} -- "\ "${BAZEL_TARGET}" - elif [[ ${CTYPE} == "gpu" ]]; then + elif [[ ${CTYPE} == "gpu" ]] || [[ ${CTYPE} == "gpu_clang" ]]; then # GPU only command, run as many jobs as the GPU count only. NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\ "--local_test_jobs=${TF_GPU_COUNT} "\ @@ -377,12 +406,7 @@ if [[ ${TF_BUILD_IS_PIP} == "pip" ]] || exit 0 fi - PIP_MAIN_CMD="${MAIN_CMD} ${PIP_CMD} ${CTYPE} ${EXTRA_AGRS}" - - # Add flag for mavx/mavx2 - if [[ ! -z "${TF_BUILD_MAVX}" ]]; then - PIP_MAIN_CMD="${PIP_MAIN_CMD} --${TF_BUILD_MAVX}" - fi + PIP_MAIN_CMD="${MAIN_CMD} ${PIP_CMD} ${CTYPE} ${EXTRA_ARGS} ${OPT_FLAG}" # Add flag for integration tests if [[ ! -z "${TF_BUILD_INTEGRATION_TESTS}" ]] && @@ -424,7 +448,8 @@ if [[ ${TF_BUILD_PYTHON_VERSION} == "python2" ]]; then : elif [[ ${TF_BUILD_PYTHON_VERSION} == "python3" || \ ${TF_BUILD_PYTHON_VERSION} == "python3.4" || \ - ${TF_BUILD_PYTHON_VERSION} == "python3.5" ]]; then + ${TF_BUILD_PYTHON_VERSION} == "python3.5" || \ + ${TF_BUILD_PYTHON_VERSION} == "python3.6" ]]; then # Supply proper environment variable to select Python 3 if [[ "${DO_DOCKER}" == "1" ]]; then EXTRA_PARAMS="${EXTRA_PARAMS} -e CI_BUILD_PYTHON=${TF_BUILD_PYTHON_VERSION}" @@ -507,11 +532,14 @@ if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]]; then DOCKERFILE="${TMP_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" # Replace a line in the Dockerfile - sed -i \ + if sed -i \ 's/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_python3.5_pip_packages.sh/g' \ - "${DOCKERFILE}" && \ - echo "Copied and modified Dockerfile for Python 3.5 build: ${DOCKERFILE}" || \ - die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" + "${DOCKERFILE}" + then + echo "Copied and modified Dockerfile for Python 3.5 build: ${DOCKERFILE}" + else + die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" + fi DOCKERFILE_FLAG="--dockerfile ${DOCKERFILE}" fi @@ -549,7 +577,7 @@ rm -f ${TMP_SCRIPT} END_TIME=$(date +'%s') echo "" echo "Parameterized build ends with ${RESULT} at: $(date) "\ -"(Elapsed time: $((${END_TIME} - ${START_TIME})) s)" +"(Elapsed time: $((END_TIME - START_TIME)) s)" # Clean up temporary directory if it exists diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 975a14e7d51..e428766a400 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -47,7 +47,7 @@ num_cpus() { # Get the hash of the last non-merge git commit on the current branch. 
# Usage: get_last_non_merge_git_commit get_last_non_merge_git_commit() { - echo $(git rev-list --no-merges -n 1 HEAD) + git rev-list --no-merges -n 1 HEAD } # List files changed (i.e., added, removed or revised) in the last non-merge @@ -75,7 +75,7 @@ get_py_files_to_check() { echo "${PY_FILES}" else - echo $(find tensorflow -name '*.py') + find tensorflow -name '*.py' fi } @@ -92,6 +92,8 @@ do_pylint() { ERROR_WHITELIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\ "^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator "\ "^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\ +"^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\ +"^tensorflow/contrib/layers/python/layers/feature_column\.py.*\[E0110.*abstract-class-instantiated "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\"" @@ -155,25 +157,25 @@ do_pylint() { NONWL_ERRORS_FILE="$(mktemp)_pylint_nonwl_errors.log" rm -rf ${OUTPUT_FILE} - rm -rf ${ERRORS_FLIE} + rm -rf ${ERRORS_FILE} rm -rf ${NONWL_ERRORS_FILE} touch ${NONWL_ERRORS_FILE} ${PYLINT_BIN} --rcfile="${PYLINTRC_FILE}" --output-format=parseable \ - --jobs=${NUM_CPUS} ${PYTHON_SRC_FILES} 2>&1 > ${OUTPUT_FILE} + --jobs=${NUM_CPUS} ${PYTHON_SRC_FILES} > ${OUTPUT_FILE} 2>&1 PYLINT_END_TIME=$(date +'%s') echo "" - echo "pylint took $((${PYLINT_END_TIME} - ${PYLINT_START_TIME})) s" + echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s" echo "" grep -E '(\[E|\[W0311|\[W0312)' ${OUTPUT_FILE} > ${ERRORS_FILE} N_ERRORS=0 - while read LINE; do + while read -r LINE; do IS_WHITELISTED=0 for WL_REGEX in ${ERROR_WHITELIST}; do - if [[ ! -z $(echo ${LINE} | grep "${WL_REGEX}") ]]; then + if echo ${LINE} | grep -q "${WL_REGEX}"; then echo "Found a whitelisted error:" echo " ${LINE}" IS_WHITELISTED=1 @@ -246,7 +248,7 @@ do_pep8() { PEP8_END_TIME=$(date +'%s') echo "" - echo "pep8 took $((${PEP8_END_TIME} - ${PEP8_START_TIME})) s" + echo "pep8 took $((PEP8_END_TIME - PEP8_START_TIME)) s" echo "" if [[ -s ${PEP8_OUTPUT_FILE} ]]; then @@ -276,7 +278,7 @@ do_buildifier(){ BUILDIFIER_END_TIME=$(date +'%s') echo "" - echo "buildifier took $((${BUILDIFIER_END_TIME} - ${BUILDIFIER_START_TIME})) s" + echo "buildifier took $((BUILDIFIER_END_TIME - BUILDIFIER_START_TIME)) s" echo "" if [[ -s ${BUILDIFIER_OUTPUT_FILE} ]]; then @@ -304,7 +306,7 @@ do_external_licenses_check(){ echo "Getting external dependencies for ${BUILD_TARGET}" bazel query "attr('licenses', 'notice', deps(${BUILD_TARGET}))" --no_implicit_deps --no_host_deps --keep_going \ - | egrep -v "^//tensorflow" \ + | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ | uniq 2>&1 \ @@ -313,7 +315,7 @@ do_external_licenses_check(){ echo echo "Getting list of external licenses mentioned in ${LICENSES_TARGET}." 
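# (Illustrative: the query pipeline below drops TensorFlow-internal targets
# with grep -E -v "^//tensorflow", strips everything from the ':' onward with
# sed, and then sorts and deduplicates, leaving one external dependency per
# line for the later comparison.)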
bazel query "deps(${LICENSES_TARGET})" --no_implicit_deps --no_host_deps --keep_going \ - | egrep -v "^//tensorflow" \ + | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ | uniq 2>&1 \ @@ -327,7 +329,7 @@ do_external_licenses_check(){ EXTERNAL_LICENSES_CHECK_END_TIME=$(date +'%s') echo - echo "do_external_licenses_check took $((${EXTERNAL_LICENSES_CHECK_END_TIME} - ${EXTERNAL_LICENSES_CHECK_START_TIME})) s" + echo "do_external_licenses_check took $((EXTERNAL_LICENSES_CHECK_END_TIME - EXTERNAL_LICENSES_CHECK_START_TIME)) s" echo if [[ -s ${MISSING_LICENSES_FILE} ]] || [[ -s ${EXTRA_LICENSES_FILE} ]] ; then @@ -371,17 +373,20 @@ do_lib_package_licenses_check() { "//tensorflow/tools/lib_package:clicenses_generate" } -# Run bazel build --nobuild to test the validity of the BUILD files -do_bazel_nobuild() { - BUILD_TARGET="//tensorflow/..." - BUILD_CMD="bazel build --nobuild ${BUILD_TARGET}" - - ${BUILD_CMD} +do_java_package_licenses_check() { + echo "Running do_java_package_licenses_check" + echo "" + do_external_licenses_check \ + "//tensorflow/java:libtensorflow_jni.so" \ + "//tensorflow/tools/lib_package:jnilicenses_generate" +} +#Check for the bazel cmd status (First arg is error message) +cmd_status(){ if [[ $? != 0 ]]; then echo "" echo "FAIL: ${BUILD_CMD}" - echo " This is due to invalid BUILD files. See lines above for details." + echo " $1 See lines above for details." return 1 else echo "" @@ -390,9 +395,32 @@ do_bazel_nobuild() { fi } +# Run bazel build --nobuild to test the validity of the BUILD files +do_bazel_nobuild() { + BUILD_TARGET="//tensorflow/..." + BUILD_CMD="bazel build --nobuild ${BUILD_TARGET}" + + ${BUILD_CMD} + + cmd_status \ + "This is due to invalid BUILD files." +} + +do_pip_smoke_test() { + BUILD_CMD="bazel build //tensorflow/tools/pip_package:pip_smoke_test" + ${BUILD_CMD} + cmd_status \ + "Pip smoke test has failed. Please make sure any new TensorFlow are added to the tensorflow/tools/pip_package:build_pip_package dependencies." + + RUN_CMD="bazel-bin/tensorflow/tools/pip_package/pip_smoke_test" + ${RUN_CMD} + cmd_status \ + "The pip smoke test failed." +} + # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package") INCREMENTAL_FLAG="" diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh index c33ea2d5cc6..6e7b752c06f 100755 --- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh +++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh @@ -24,6 +24,17 @@ # TF_GPU_COUNT = Number of GPUs available. This HAS TO BE IN SYNC with the # value of --local_test_jobs flag for bazel. 
+BASH_VER_MAJOR=$(echo ${BASH_VERSION} | cut -d '.' -f 1) +BASH_VER_MINOR=$(echo ${BASH_VERSION} | cut -d '.' -f 2) + +if [[ ${BASH_VER_MAJOR} -lt 4 ]]; then + echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2 + exit 1 +elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then + echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2 + exit 1 +fi + TF_GPU_COUNT=${TF_GPU_COUNT:-8} for i in `seq 0 $((TF_GPU_COUNT-1))`; do diff --git a/tensorflow/tools/ci_build/install/build_and_install_clang.sh b/tensorflow/tools/ci_build/install/build_and_install_clang.sh new file mode 100755 index 00000000000..3fb99649485 --- /dev/null +++ b/tensorflow/tools/ci_build/install/build_and_install_clang.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +LLVM_SVN_REVISION="299268" +CLANG_TMP_DIR=/tmp/clang-build + +mkdir "$CLANG_TMP_DIR" + +pushd "$CLANG_TMP_DIR" + +# Checkout llvm+clang +svn co -q -r$LLVM_SVN_REVISION http://llvm.org/svn/llvm-project/llvm/trunk "$CLANG_TMP_DIR/llvm" +svn co -q -r$LLVM_SVN_REVISION http://llvm.org/svn/llvm-project/cfe/trunk "$CLANG_TMP_DIR/llvm/tools/clang" + +# Build 1st stage. Compile clang with system compiler +mkdir "$CLANG_TMP_DIR/build-1" +cd "$CLANG_TMP_DIR/build-1" +cmake -G"Unix Makefiles" -DCMAKE_BUILD_TYPE=Release "$CLANG_TMP_DIR/llvm" +make -j `nproc` clang clang-headers + +# Build 2nd stage. Compile clang with clang built in stage 1 +mkdir "$CLANG_TMP_DIR/build-2" +cd "$CLANG_TMP_DIR/build-2" + +CC="$CLANG_TMP_DIR/build-1/bin/clang" \ +CXX="$CLANG_TMP_DIR/build-1/bin/clang++" \ +cmake -G"Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local "$CLANG_TMP_DIR/llvm" + +make -j `nproc` install-clang install-clang-headers + +popd + +# Cleanup +rm -rf "$CLANG_TMP_DIR" diff --git a/tensorflow/tools/ci_build/install/install_auditwheel.sh b/tensorflow/tools/ci_build/install/install_auditwheel.sh index 2538a393d3f..e6f6124d567 100755 --- a/tensorflow/tools/ci_build/install/install_auditwheel.sh +++ b/tensorflow/tools/ci_build/install/install_auditwheel.sh @@ -16,7 +16,7 @@ set -e -sudo pip3 install auditwheel +sudo pip3 install auditwheel==1.5.0 set +e patchelf_location=$(which patchelf) diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index 6807202f7e9..daba126f889 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. 
-BAZEL_VERSION="0.4.2" +BAZEL_VERSION="0.5.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/install/install_buildifier.sh b/tensorflow/tools/ci_build/install/install_buildifier.sh index 2f3470881a5..b2dfcf8db76 100755 --- a/tensorflow/tools/ci_build/install/install_buildifier.sh +++ b/tensorflow/tools/ci_build/install/install_buildifier.sh @@ -16,8 +16,9 @@ set -e BUILDIFIER_DIR="buildifier" -rm -rf ${BUILDIFIER_DIR} -git clone https://github.com/bazelbuild/buildifier.git ${BUILDIFIER_DIR} +mkdir ${BUILDIFIER_DIR} +curl -Ls https://github.com/bazelbuild/buildifier/archive/0.4.5.tar.gz | \ + tar -C "${BUILDIFIER_DIR}" --strip-components=1 -xz pushd ${BUILDIFIER_DIR} bazel build buildifier:buildifier --spawn_strategy=standalone --genrule_strategy=standalone diff --git a/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh b/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh new file mode 100755 index 00000000000..3e626a69ab5 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +CMAKE_URL="https://cmake.org/files/v3.7/cmake-3.7.2-Linux-x86_64.tar.gz" + +wget -O - "${CMAKE_URL}" | tar xzf - -C /usr/local --strip-components=1 diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 227b83ab9f6..da1f2199d0d 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -13,11 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# +# Usage: +# ./install_deb_packages [--without_cmake] +# Pass --without_cmake to prevent cmake from being installed with apt-get set -e ubuntu_version=$(cat /etc/issue | grep -i ubuntu | awk '{print $2}' | \ awk -F'.' '{print $1}') +if [[ "$1" != "" ]] && [[ "$1" != "--without_cmake" ]]; then + echo "Unknown argument '$1'" + exit 1 +fi + # Install dependencies from ubuntu deb repository. apt-get update @@ -32,28 +41,38 @@ apt-get install -y --no-install-recommends \ autoconf \ automake \ build-essential \ - cmake \ curl \ ffmpeg \ git \ libcurl4-openssl-dev \ libtool \ + mlocate \ openjdk-8-jdk \ openjdk-8-jre-headless \ pkg-config \ python-dev \ - python-pip \ + python-setuptools \ python-virtualenv \ python3-dev \ - python3-pip \ + python3-setuptools \ rsync \ sudo \ + subversion \ swig \ unzip \ wget \ zip \ zlib1g-dev +# populate the database +updatedb + +if [[ "$1" != "--without_cmake" ]]; then + apt-get install -y --no-install-recommends \ + cmake +fi + + # Install ca-certificates, and update the certificate store. 
apt-get install -y ca-certificates-java update-ca-certificates -f diff --git a/tensorflow/tools/ci_build/install/install_golang.sh b/tensorflow/tools/ci_build/install/install_golang.sh new file mode 100755 index 00000000000..fef203b8697 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_golang.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +GOLANG_URL="https://storage.googleapis.com/golang/go1.7.5.linux-amd64.tar.gz" + +sudo mkdir -p /usr/local +wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 8e374df6321..c9867796f3a 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -16,153 +16,75 @@ set -e +# We don't apt-get install so that we can install a newer version of pip. Not +# needed after we upgrade to Ubuntu 16.04 +easy_install -U pip +easy_install3 -U pip + # Install pip packages from whl files to avoid the time-consuming process of # building from source. -pip install wheel +pip2 install wheel pip3 install wheel # Install six. -pip install --upgrade six==1.10.0 +pip2 install --upgrade six==1.10.0 pip3 install --upgrade six==1.10.0 # Install werkzeug. -pip install --upgrade werkzeug==0.11.10 +pip2 install --upgrade werkzeug==0.11.10 pip3 install --upgrade werkzeug==0.11.10 +# Install bleach. html5lib will be picked up as a dependency. +pip2 install --upgrade bleach==1.5.0 +pip3 install --upgrade bleach==1.5.0 + +# Install markdown. +pip2 install --upgrade markdown==2.6.8 +pip3 install --upgrade markdown==2.6.8 + # Install protobuf. -pip install --upgrade protobuf==3.0.0 -pip3 install --upgrade protobuf==3.0.0 +pip2 install --upgrade protobuf==3.3.0 +pip3 install --upgrade protobuf==3.3.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* -set +e -# Use pip to install numpy to a modern version, instead of 1.8.2 that comes -# with apt-get in ubuntu:14.04. 
-NUMPY_VERSION="1.11.0" -numpy_ver_flat=$(echo $NUMPY_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_numpy_ver=$(python -c "import numpy; print(numpy.__version__)") -local_numpy_ver_flat=$(echo $local_numpy_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_numpy_ver_flat ]]; then - local_numpy_ver_flat=0 -fi -if (( $local_numpy_ver_flat < $numpy_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/06/92/3c786303889e6246971ad4c48ac2b4e37a1b1c67c0dc2106dc85cb15c18e/numpy-1.11.0-cp27-cp27mu-manylinux1_x86_64.whl#md5=6ffb66ff78c28c55bfa09a2ceee487df - mv numpy-1.11.0-cp27-cp27mu-manylinux1_x86_64.whl \ - numpy-1.11.0-cp27-none-linux_x86_64.whl - pip install numpy-1.11.0-cp27-none-linux_x86_64.whl - rm numpy-1.11.0-cp27-none-linux_x86_64.whl -fi +# numpy needs to be installed from source to fix segfaults. See: +# https://github.com/tensorflow/tensorflow/issues/6968 +# This workaround isn't needed for Ubuntu 16.04 or later. +pip2 install --no-binary=:all: --upgrade numpy==1.12.0 +pip3 install --no-binary=:all: --upgrade numpy==1.12.0 -set +e -local_numpy_ver=$(python3 -c "import numpy; print(numpy.__version__)") -local_numpy_ver_flat=$(echo $local_numpy_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_numpy_ver_flat ]]; then - local_numpy_ver_flat=0 -fi -if (( $local_numpy_ver_flat < $numpy_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/ea/ca/5e48a68be496e6f79c3c8d90f7c03ea09bbb154ea4511f5b3d6c825cefe5/numpy-1.11.0-cp34-cp34m-manylinux1_x86_64.whl#md5=08a002aeffa20354aa5045eadb549361 - mv numpy-1.11.0-cp34-cp34m-manylinux1_x86_64.whl \ - numpy-1.11.0-cp34-none-linux_x86_64.whl - pip3 install numpy-1.11.0-cp34-none-linux_x86_64.whl - rm numpy-1.11.0-cp34-none-linux_x86_64.whl -fi +pip2 install scipy==0.18.1 +pip3 install scipy==0.18.1 -# Use pip to install scipy to get the latest version, instead of 0.13 through -# apt-get. 
-# pip install scipy==0.15.1 -set +e -SCIPY_VERSION="0.15.1" -scipy_ver_flat=$(echo $SCIPY_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_scipy_ver=$(python -c "import scipy; print(scipy.__version__)") -local_scipy_ver_flat=$(echo $local_scipy_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_scipy_ver_flat ]]; then - local_scipy_ver_flat=0 -fi -if (( $local_scipy_ver_flat < $scipy_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/00/0f/060ec52cb74dc8df1a7ef1a524173eb0bcd329110404869b392685cfc5c8/scipy-0.15.1-cp27-cp27mu-manylinux1_x86_64.whl#md5=aaac02e6535742ab02f2075129890714 - mv scipy-0.15.1-cp27-cp27mu-manylinux1_x86_64.whl \ - scipy-0.15.1-cp27-none-linux_x86_64.whl - pip install scipy-0.15.1-cp27-none-linux_x86_64.whl - rm scipy-0.15.1-cp27-none-linux_x86_64.whl -fi - -# pip3 install scipy==0.15.1 -set +e -local_scipy_ver=$(python3 -c "import scipy; print(scipy.__version__)") -local_scipy_ver_flat=$(echo $local_scipy_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_scipy_ver_flat ]]; then - local_scipy_ver_flat=0 -fi -if (( $local_scipy_ver_flat < $scipy_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/56/c5/e0d36aaf719aa02ee3da19151045912e240d145586612e53b5eaa706e1db/scipy-0.15.1-cp34-cp34m-manylinux1_x86_64.whl#md5=d5243b0f9d85f4f4cb62514c82af93d4 - mv scipy-0.15.1-cp34-cp34m-manylinux1_x86_64.whl \ - scipy-0.15.1-cp34-cp34m-linux_x86_64.whl - pip3 install scipy-0.15.1-cp34-cp34m-linux_x86_64.whl - rm scipy-0.15.1-cp34-cp34m-linux_x86_64.whl -fi - -# pip install sklearn -set +e -SKLEARN_VERSION="0.17.1" -sklearn_ver_flat=$(echo $SKLEARN_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_sklearn_ver=$(python -c "import sklearn; print(sklearn.__version__)") -local_sklearn_ver_flat=$(echo $local_sklearn_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_sklearn_ver_flat ]]; then - local_sklearn_ver_flat=0 -fi -if (( $local_sklearn_ver_flat < $sklearn_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/bf/80/06e77e5a682c46a3880ec487a5f9d910f5c8d919df9aca58052089687c7e/scikit_learn-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl#md5=337b91f502138ba7fd722803138f6dfd - mv scikit_learn-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl \ - scikit_learn-0.17.1-cp27-none-linux_x86_64.whl - pip install scikit_learn-0.17.1-cp27-none-linux_x86_64.whl - rm scikit_learn-0.17.1-cp27-none-linux_x86_64.whl -fi - -# pip3 install scikit-learn -set +e -local_sklearn_ver=$(python3 -c "import sklearn; print(sklearn.__version__)") -local_sklearn_ver_flat=$(echo $local_sklearn_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_sklearn_ver_flat ]]; then - local_sklearn_ver_flat=0 -fi -if (( $local_sklearn_ver_flat < $sklearn_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/7e/f1/1cc8a1ae2b4de89bff0981aee904ff05779c49a4c660fa38178f9772d3a7/scikit_learn-0.17.1-cp34-cp34m-manylinux1_x86_64.whl#md5=a722a7372b64ec9f7b49a2532d21372b - mv scikit_learn-0.17.1-cp34-cp34m-manylinux1_x86_64.whl \ - scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl - pip3 install scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl - rm scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl -fi - -set -e +pip2 install scikit-learn==0.18.1 +pip3 install scikit-learn==0.18.1 # pandas required by tf.learn/inflow -pip install pandas==0.18.1 -pip3 install pandas==0.18.1 +pip2 install pandas==0.19.2 +pip3 install pandas==0.19.2 # Benchmark tests require the following: -pip install psutil +pip2 install psutil pip3 install psutil -pip install py-cpuinfo +pip2 install 
py-cpuinfo pip3 install py-cpuinfo # pylint tests require the following: -pip install pylint -pip3 install pylint +pip2 install pylint==1.6.4 +pip3 install pylint==1.6.4 # pep8 tests require the following: -pip install pep8 +pip2 install pep8 pip3 install pep8 # tf.mock require the following for python2: -pip install mock +pip2 install mock -pip install portpicker +pip2 install portpicker pip3 install portpicker + +pip2 install backports.weakref==1.0rc1 +pip3 install backports.weakref==1.0rc1 diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh index 773c89b70bb..7934002b2c9 100755 --- a/tensorflow/tools/ci_build/install/install_proto3.sh +++ b/tensorflow/tools/ci_build/install/install_proto3.sh @@ -17,9 +17,9 @@ # Install protobuf3. # Select protobuf version. -PROTOBUF_VERSION="3.2.0" +PROTOBUF_VERSION="3.3.0" protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_protobuf_ver=$(protoc --version | awk '{print $2}') +local_protobuf_ver=$(protoc --version) local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g') if [[ -z $local_protobuf_ver_flat ]]; then local_protobuf_ver_flat=0 @@ -30,7 +30,7 @@ if (( $local_protobuf_ver_flat < $protobuf_ver_flat )); then PROTOBUF_ZIP=$(basename "${PROTOBUF_URL}") UNZIP_DEST="google-protobuf" - wget -q "${PROTOBUF_URL}" + wget "${PROTOBUF_URL}" unzip "${PROTOBUF_ZIP}" -d "${UNZIP_DEST}" cp "${UNZIP_DEST}/bin/protoc" /usr/local/bin/ diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 0c86db71192..33b3bc104bd 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -64,37 +64,31 @@ set -e pip3.5 install --upgrade six==1.10.0 # Install protobuf. -pip3.5 install --upgrade protobuf==3.0.0 +pip3.5 install --upgrade protobuf==3.3.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* # Install numpy, scipy and scikit-learn required by the builds -pip3.5 install --upgrade numpy -set +e -SCIPY_VERSION="0.17.1" -scipy_ver_flat=$(echo $SCIPY_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_scipy_ver=$(python3.5 -c "import scipy; print(scipy.__version__)") -local_scipy_ver_flat=$(echo $local_scipy_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_scipy_ver_flat ]]; then - local_scipy_ver_flat=0 -fi -if (( $local_scipy_ver_flat < $scipy_ver_flat )); then - set -e - wget -q https://pypi.python.org/packages/91/f3/0052c245d53eb5f0e13b7215811e52af3791a8a7d31771605697c28466a0/scipy-0.17.1-cp35-cp35m-manylinux1_x86_64.whl#md5=8e77756904c81a6f79ed10e3abf0c544 - pip3.5 install --upgrade scipy-0.17.1-cp35-cp35m-manylinux1_x86_64.whl - rm -f scipy-0.17.1-cp35-cp35m-manylinux1_x86_64.whl -fi +# numpy needs to be installed from source to fix segfaults. See: +# https://github.com/tensorflow/tensorflow/issues/6968 +# This workaround isn't needed for Ubuntu 16.04 or later. 
+pip3.5 install --no-binary=:all: --upgrade numpy==1.12.0 -set -e -pip3.5 install --upgrade scikit-learn +pip3.5 install scipy==0.18.1 + +pip3.5 install scikit-learn==0.18.1 + +# pandas required by tf.learn/inflow +pip3.5 install pandas==0.19.2 # Install recent-enough version of wheel for Python 3.5 wheel builds pip3.5 install wheel==0.29.0 -pip3.5 install --upgrade pandas==0.18.1 - pip3.5 install portpicker pip3.5 install werkzeug + +pip3.5 install backports.weakref==1.0rc1 + diff --git a/tensorflow/tools/ci_build/linux/cmake/run.sh b/tensorflow/tools/ci_build/linux/cmake/run.sh old mode 100644 new mode 100755 diff --git a/tensorflow/tools/ci_build/linux/cpu/run.sh b/tensorflow/tools/ci_build/linux/cpu/run.sh deleted file mode 100644 index 4ab545ecb9a..00000000000 --- a/tensorflow/tools/ci_build/linux/cpu/run.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ============================================================================== - -set -e -set -x - -N_JOBS=$(grep -c ^processor /proc/cpuinfo) - -echo "" -echo "Bazel will use ${N_JOBS} concurrent job(s)." -echo "" - -# Run configure. -export TF_NEED_GCP=0 -export TF_NEED_HDFS=0 -export TF_NEED_CUDA=0 -export PYTHON_BIN_PATH=`which python2` -yes "" | ./configure - -# Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 //tensorflow/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh new file mode 100755 index 00000000000..467e4ab7e53 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_NEED_CUDA=0 +# Only running cc tests, python version does not matter. +export PYTHON_BIN_PATH=`which python` +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=cc -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ + --test_output=errors -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh new file mode 100755 index 00000000000..e2bbc0e8c0b --- /dev/null +++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_NEED_CUDA=0 +export PYTHON_BIN_PATH=`which python2` +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ + --test_output=errors -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh new file mode 100755 index 00000000000..a03cab0cca5 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_NEED_CUDA=0 +export PYTHON_BIN_PATH=`which python3` +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --test_tag_filters=-gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --test_output=errors -- \ + //tensorflow/contrib/... 
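The new per-configuration runner scripts above all share one shape: pin the `./configure` answers through `TF_NEED_*` environment variables, feed any remaining prompts with `yes ""`, then let tag and language filters pick the tests. To sanity-check a filter combination before committing to a full `//tensorflow/...` run, the same flags can be pointed at a single package first; this is a hedged sketch, and the target package below is an arbitrary example rather than anything these scripts reference:

```bash
# Dry-run the CPU python-test filters against one package as a cheap smoke check.
bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \
  --test_timeout 300,450,1200,3600 --build_tests_only \
  --test_output=errors -- \
  //tensorflow/python/kernel_tests/...
```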
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh new file mode 100755 index 00000000000..32de5cea200 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_NEED_CUDA=0 +export PYTHON_BIN_PATH=`which python3` +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ + --test_output=errors -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh new file mode 100755 index 00000000000..6acc2621383 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export PYTHON_BIN_PATH=`which python3` + +export TF_NEED_CUDA=1 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ + --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --build_tests_only --test_output=errors --local_test_jobs=8 \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... 
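The `--run_under` wrapper in the GPU runners above is what makes `--local_test_jobs=8` safe: each test gets pinned to one free GPU instead of every test landing on device 0. The sketch below illustrates the locking idea only; it is not the actual contents of `parallel_gpu_execute.sh`, the lock-file path is a made-up example, and newer-bash features of this sort are presumably what the bash 4.2 version check added earlier guards against:

```bash
# Illustration: claim the first free GPU slot via flock, expose only that
# device to the test, then run the command bazel hands the wrapper ("$@").
TF_GPU_COUNT=${TF_GPU_COUNT:-8}
for i in $(seq 0 $((TF_GPU_COUNT - 1))); do
  exec {lock_fd}>"/tmp/tf_gpu_lock_${i}"   # hypothetical lock-file location
  if flock -n "${lock_fd}"; then
    CUDA_VISIBLE_DEVICES="${i}" "$@"       # the test sees exactly one GPU
    exit $?
  fi
  exec {lock_fd}>&-                        # slot busy: close the fd, try the next
done
echo "Could not acquire any of the ${TF_GPU_COUNT} GPU locks" >&2
exit 1
```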
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh new file mode 100755 index 00000000000..e73fe046c96 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export PYTHON_BIN_PATH=`which python3` + +export TF_NEED_CUDA=1 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ + --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --build_tests_only --test_output=errors --local_test_jobs=8 \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/libtensorflow.sh b/tensorflow/tools/ci_build/linux/libtensorflow.sh index bc64fabde5b..beef8e063b3 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow.sh @@ -14,9 +14,8 @@ # limitations under the License. # ============================================================================== # -# Script to produce a tarball release of the C-library and associated C API -# header file. Intended to be run inside a docker container. See -# libtensorflow_docker.sh +# Script to produce binary releases for libtensorflow (C API, Java jars etc.). +# Intended to be run inside a docker container. See libtensorflow_docker.sh set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh index c300c4670fd..4bf34dd2993 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh @@ -14,8 +14,7 @@ # limitations under the License. # ============================================================================== # -# Script to build a binary release tarball for the TensorFlow C-library without -# GPU support. +# Script to build a binary release of libtensorflow without GPU support. set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh index 5423831caad..dcda8228bc2 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh @@ -14,9 +14,9 @@ # limitations under the License.
# ============================================================================== # -# Script to produce a tarball release of the C-library and associated C API -# header file. Builds a docker container and then builds the C-library in -# said container. +# Script to produce a tarball release of the C-library, Java native library +# and Java .jars. +# Builds a docker container and then builds in said container. # # See libtensorflow_cpu.sh and libtensorflow_gpu.sh @@ -29,9 +29,9 @@ DOCKER_IMAGE="tf-libtensorflow-cpu" DOCKER_FILE="Dockerfile.cpu" DOCKER_BINARY="docker" if [ "${TF_NEED_CUDA}" == "1" ]; then - DOCKER_IMAGE="tf-tensorflow-gpu" - DOCKER_BINARY="nvidia-docker" - DOCKER_FILE="Dockerfile.gpu" + DOCKER_IMAGE="tf-tensorflow-gpu" + DOCKER_BINARY="nvidia-docker" + DOCKER_FILE="Dockerfile.gpu" fi docker build \ diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/linux/libtensorflow_gpu.sh index e13098c220d..6dca0c37c87 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_gpu.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_gpu.sh @@ -14,8 +14,7 @@ # limitations under the License. # ============================================================================== # -# Script to build a binary release tarball for the TensorFlow C-library without -# GPU support. +# Script to build a binary release of libtensorflow with GPU support. set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh new file mode 100755 index 00000000000..e5f4a22f7ad --- /dev/null +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(sysctl -n hw.ncpu) +N_JOBS=$((N_JOBS+1)) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export TF_NEED_CUDA=0 +export PYTHON_BIN_PATH=$(which python2) +yes "" | ./configure +which bazel +bazel test --test_tag_filters=-gpu,-benchmark-test,-nomac \ + --test_timeout 300,450,1200,3600 \ + --test_size_filters=small,medium \ + --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... \ + -//tensorflow/tensorboard/... diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh index 432201156d7..d90a1b905d9 100755 --- a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh @@ -14,9 +14,7 @@ # limitations under the License.
# ============================================================================== # -# Script to produce a tarball release of the C-library and associated C API -# header file. -# Produces: lib_package/libtensorflow-gpu-darwin-x86_64.tar.gz +# Script to produce binary release of libtensorflow (C API, Java jars etc.). set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -30,6 +28,7 @@ export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_CUDA=0 export TF_NEED_OPENCL=0 +export TF_NEED_MKL=0 export COMPUTECPP_PATH="/usr/local" export PATH="/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh index 5e6f4b9fc2d..79973647c11 100755 --- a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh +++ b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh @@ -14,9 +14,7 @@ # limitations under the License. # ============================================================================== # -# Script to produce a tarball release of the C-library and associated C API -# header file. -# Produces: lib_package/libtensorflow-gpu-darwin-x86_64.tar.gz +# Script to produce binary release of libtensorflow (C API, Java jars etc.). set -ex SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -31,6 +29,7 @@ export PYTHON_BIN_PATH="/usr/bin/python" export TF_NEED_GCP=0 export TF_NEED_HDFS=0 export TF_NEED_OPENCL=0 +export TF_NEED_MKL=0 export COMPUTECPP_PATH="/usr/local" export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin" diff --git a/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh b/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh index 59ba71f5df7..3e31aa1ce10 100755 --- a/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh +++ b/tensorflow/tools/ci_build/protobuf/protobuf_optimized_pip.sh @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -PROTOBUF_VERSION="3.2.0" +PROTOBUF_VERSION="3.3.1" PYTHON_BIN=${PYTHON_BIN:-python} DIR=${PWD}/protobuf diff --git a/tensorflow/tools/ci_build/pylintrc b/tensorflow/tools/ci_build/pylintrc index 0779ed91bc3..e71017e621c 100644 --- a/tensorflow/tools/ci_build/pylintrc +++ b/tensorflow/tools/ci_build/pylintrc @@ -38,7 +38,7 @@ enable=indexing-exception,old-raise-syntax # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager # Set the cache size for astng objects. @@ -322,4 +322,4 @@ indent-after-paren=4 [GOOGLE LINES] # Regexp for a proper copyright notice. -copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. \ No newline at end of file +copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\. 
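The pylintrc hunk above adds `not-context-manager` to the global `disable` list. That message fires when a `with` statement uses an object pylint cannot statically prove to implement `__enter__`/`__exit__`, a frequent false positive on TensorFlow's dynamically generated ops and decorated context managers, so disabling it globally is cheaper than whitelisting each occurrence. To see what the check would have flagged, it can be run in isolation; the file below is an arbitrary example, not one named by this change:

```bash
# Run only the not-context-manager check (message id E1129 in recent pylint).
pylint --disable=all --enable=not-context-manager \
  tensorflow/python/framework/ops.py
```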
diff --git a/tensorflow/tools/ci_build/update_version.sh b/tensorflow/tools/ci_build/update_version.sh index 759c7e5f7e4..682f5329f58 100755 --- a/tensorflow/tools/ci_build/update_version.sh +++ b/tensorflow/tools/ci_build/update_version.sh @@ -61,7 +61,9 @@ fi MAJOR=$(echo "${NEW_VER}" | cut -d \. -f 1) MINOR=$(echo "${NEW_VER}" | cut -d \. -f 2) PATCH=$(echo "${NEW_VER}" | cut -d \. -f 3) +PATCH_NUM=$(echo "$PATCH" | cut -d \- -f 1) PIP_PATCH="${PATCH//-}" +SUFFIX=$(echo $NEW_VER | sed "s/${MAJOR}.${MINOR}.${PATCH%-*}//g") # Update tensorflow/core/public/version.h VERSION_H="${TF_SRC_DIR}/core/public/version.h" @@ -71,13 +73,17 @@ OLD_MAJOR=$(cat ${VERSION_H} | grep -E "^#define TF_MAJOR_VERSION [0-9]+" | \ cut -d ' ' -f 3) OLD_MINOR=$(cat ${VERSION_H} | grep -E "^#define TF_MINOR_VERSION [0-9]+" | \ cut -d ' ' -f 3) -OLD_PATCH=$(cat ${VERSION_H} | grep -E "^#define TF_PATCH_VERSION [[:alnum:]-]+" | \ +OLD_PATCH_NUM=$(cat ${VERSION_H} | grep -E "^#define TF_PATCH_VERSION [[:alnum:]-]+" | \ cut -d ' ' -f 3) +OLD_EXTENSION=$(cat ${VERSION_H} | grep -E "^#define TF_VERSION_SUFFIX \"[[:alnum:]-]+\"" | \ +cut -d ' ' -f 3) +OLD_PATCH="$OLD_PATCH_NUM${OLD_EXTENSION//\"}" +OLD_PIP_PATCH="${OLD_PATCH//-}" sed -i -e "s/^#define TF_MAJOR_VERSION ${OLD_MAJOR}/#define TF_MAJOR_VERSION ${MAJOR}/g" ${VERSION_H} sed -i -e "s/^#define TF_MINOR_VERSION ${OLD_MINOR}/#define TF_MINOR_VERSION ${MINOR}/g" ${VERSION_H} -sed -i -e "s/^#define TF_PATCH_VERSION ${OLD_PATCH}/#define TF_PATCH_VERSION ${PATCH}/g" "${VERSION_H}" - +sed -i -e "s/^#define TF_PATCH_VERSION ${OLD_PATCH}/#define TF_PATCH_VERSION ${PATCH_NUM}/g" "${VERSION_H}" +sed -i -e "s/^#define TF_VERSION_SUFFIX \".*\"/#define TF_VERSION_SUFFIX \"${SUFFIX}\"/g" "${VERSION_H}" # Update setup.py SETUP_PY="${TF_SRC_DIR}/tools/pip_package/setup.py" @@ -85,23 +91,6 @@ check_existence file "${SETUP_PY}" sed -i -e "s/^\_VERSION = [\'\"].*[\'\"]/\_VERSION = \'${MAJOR}.${MINOR}.${PATCH}\'/g" "${SETUP_PY}" -# Update cmake setup.py -CMAKE_SETUP_PY="${TF_SRC_DIR}/contrib/cmake/setup.py" -check_existence file "${CMAKE_SETUP_PY}" - -sed -i -e "s/^\_VERSION = [\'\"].*-cmake-experimental[\'\"]/\_VERSION = \'${MAJOR}.${MINOR}.${PATCH}-cmake-experimental\'/g" "${CMAKE_SETUP_PY}" - - -# Update os_setup.md -OS_SETUP="${TF_SRC_DIR}/g3doc/get_started/os_setup.md" -check_existence file "${OS_SETUP}" - -sed -i -r -e "s/(.*pip[0-9]* install .*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PIP_PATCH}\3/g" "${OS_SETUP}" -sed -i -r -e "s/(.*pip[0-9]* install .*tensorflow_gpu-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PIP_PATCH}\3/g" "${OS_SETUP}" -sed -i -r -e "s/(.*export TF_BINARY_URL.*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PIP_PATCH}\3/g" "${OS_SETUP}" -sed -i -r -e "s/(.*export TF_BINARY_URL.*tensorflow_gpu-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PIP_PATCH}\3/g" "${OS_SETUP}" -sed -i -r -e "s/(.*\`)([0-9]+\.[0-9]+\.[[:alnum:]-]+)(-gpu.*)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}" - # Update README.md README_MD="./README.md" @@ -109,6 +98,26 @@ check_existence file "${README_MD}" sed -i -r -e "s/${OLD_MAJOR}\.${OLD_MINOR}\.([[:alnum:]]+)-/${MAJOR}.${MINOR}.${PIP_PATCH}-/g" "${README_MD}" +# Update the install md files +NEW_PIP_TAG=$MAJOR.$MINOR.$PIP_PATCH +OLD_PIP_TAG=$OLD_MAJOR.$OLD_MINOR.$OLD_PIP_PATCH + +for file in ${TF_SRC_DIR}/docs_src/install/install_{linux,mac,windows,sources}.md +do + sed -i "s/tensorflow-${OLD_PIP_TAG}/tensorflow-${NEW_PIP_TAG}/g" $file 
+ sed -i "s/tensorflow_gpu-${OLD_PIP_TAG}/tensorflow_gpu-${NEW_PIP_TAG}/g" $file + sed -i "s/TensorFlow ${OLD_PIP_TAG}/TensorFlow ${NEW_PIP_TAG}/g" $file +done + +NEW_TAG=$MAJOR.$MINOR.$PATCH +OLD_TAG=$OLD_MAJOR.$OLD_MINOR.$OLD_PATCH + +for file in ${TF_SRC_DIR}/docs_src/install/install_{java,go,c}.md +do + sed -i "s/x86_64-${OLD_TAG}/x86_64-${NEW_TAG}/g" $file + sed -i "s/libtensorflow-${OLD_TAG}.jar/libtensorflow-${NEW_TAG}.jar/g" $file + sed -i "s/${OLD_TAG}<\/version>/${NEW_TAG}<\/version>/g" $file +done # Updates to be made if there are major / minor version changes MAJOR_MINOR_CHANGE=0 diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 4aae0378a8d..dff4707cbef 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -35,7 +35,6 @@ failing_cpu_cc_tests="\ " broken_cpu_cc_tests="\ - //tensorflow/core/kernels/hexagon:graph_transferer_test + \ //tensorflow/cc:framework_cc_ops_test + \ //tensorflow/core/platform/cloud:time_util_test + \ //tensorflow/core/platform/cloud:oauth_client_test + \ @@ -43,7 +42,9 @@ broken_cpu_cc_tests="\ //tensorflow/core/platform/cloud:google_auth_provider_test + \ //tensorflow/core/platform/cloud:gcs_file_system_test + \ //tensorflow/core/kernels/cloud:bigquery_table_accessor_test + \ + //tensorflow/core/kernels/hexagon:graph_transferer_test + \ //tensorflow/core/kernels/hexagon:quantized_matmul_op_for_hexagon_test + \ + //tensorflow/core/kernels:remote_fused_graph_execute_utils_test + \ //tensorflow/core/kernels:requantize_op_test + \ //tensorflow/core/kernels:requantization_range_op_test + \ //tensorflow/core/kernels:quantized_reshape_op_test + \ @@ -95,65 +96,6 @@ exclude_cpu_cc_tests="${failing_cpu_cc_tests} + ${broken_cpu_cc_tests}" exclude_gpu_cc_tests="${extra_failing_gpu_cc_tests} + ${exclude_cpu_cc_tests}" -# Python tests -# The first argument is the name of the python test direcotry -function get_failing_cpu_py_tests() { - echo " - //$1/tensorflow/python:basic_session_run_hooks_test + \ - //$1/tensorflow/python:bigquery_reader_ops_test + \ - //$1/tensorflow/python:contrib_test + \ - //$1/tensorflow/python:dequantize_op_test + \ - //$1/tensorflow/python:directory_watcher_test + \ - //$1/tensorflow/python:event_multiplexer_test + \ - //$1/tensorflow/python:file_io_test + \ - //$1/tensorflow/python:file_system_test + \ - //$1/tensorflow/python:framework_meta_graph_test + \ - //$1/tensorflow/python:framework_ops_test + \ - //$1/tensorflow/python:framework_tensor_util_test + \ - //$1/tensorflow/python:framework_test_util_test + \ - //$1/tensorflow/python:gradients_test + \ - //$1/tensorflow/python:image_ops_test + \ - //$1/tensorflow/python:localhost_cluster_performance_test + \ - //$1/tensorflow/python:monitored_session_test + \ - //$1/tensorflow/python:nn_batchnorm_test + \ - //$1/tensorflow/python:protobuf_compare_test + \ - //$1/tensorflow/python:quantized_conv_ops_test + \ - //$1/tensorflow/python:saver_large_variable_test + \ - //$1/tensorflow/python:saver_test + \ - //$1/tensorflow/python:session_test + \ - //$1/tensorflow/python:supervisor_test + \ - //$1/tensorflow/python:sync_replicas_optimizer_test + \ - //$1/tensorflow/python/debug:curses_ui_test + \ - //$1/tensorflow/python/kernel_tests:as_string_op_test + \ - //$1/tensorflow/python/kernel_tests:benchmark_test + \ - //$1/tensorflow/python/kernel_tests:cast_op_test + \ - //$1/tensorflow/python/kernel_tests:clip_ops_test + \ - 
//$1/tensorflow/python/kernel_tests:conv_ops_test + \ - //$1/tensorflow/python/kernel_tests:decode_image_op_test + \ - //$1/tensorflow/python/kernel_tests:depthwise_conv_op_test + \ - //$1/tensorflow/python/kernel_tests:functional_ops_test + \ - //$1/tensorflow/python/kernel_tests:py_func_test + \ - //$1/tensorflow/python/kernel_tests:rnn_test + \ - //$1/tensorflow/python/kernel_tests:sets_test + \ - //$1/tensorflow/python/kernel_tests:sparse_matmul_op_test + \ - //$1/tensorflow/python/kernel_tests:string_to_number_op_test + \ - //$1/tensorflow/python/kernel_tests:summary_ops_test + \ - //$1/tensorflow/python/kernel_tests:variable_scope_test + \ - //$1/tensorflow/python/saved_model:saved_model_test \ - " -} - -function get_failing_gpu_py_tests() { - echo " - //$1/tensorflow/python/kernel_tests:diag_op_test + \ - //$1/tensorflow/python/kernel_tests:one_hot_op_test + \ - //$1/tensorflow/python/kernel_tests:rnn_test + \ - //$1/tensorflow/python/kernel_tests:sets_test + \ - //$1/tensorflow/python/kernel_tests:trace_op_test + \ - $(get_failing_cpu_py_tests $1) - " -} - function clean_output_base() { # TODO(pcloudy): bazel clean --expunge doesn't work on Windows yet. # Clean the output base manually to ensure build correctness @@ -177,6 +119,13 @@ function run_configure_for_cpu_build { if [ -z "$CC_OPT_FLAGS" ]; then export CC_OPT_FLAGS="-march=native" fi + if [ -z "$TF_NEED_MKL" ]; then + export TF_NEED_MKL=0 + fi + export TF_NEED_VERBS=0 + export TF_NEED_GCP=0 + export TF_NEED_HDFS=0 + export TF_NEED_OPENCL=0 echo "" | ./configure } @@ -196,6 +145,11 @@ function run_configure_for_gpu_build { if [ -z "$CC_OPT_FLAGS" ]; then export CC_OPT_FLAGS="-march=native" fi + export TF_NEED_VERBS=0 + export TF_NEED_MKL=0 + export TF_NEED_GCP=0 + export TF_NEED_HDFS=0 + export TF_NEED_OPENCL=0 echo "" | ./configure } @@ -207,5 +161,5 @@ function create_python_test_dir() { function reinstall_tensorflow_pip() { echo "y" | pip uninstall tensorflow -q || true - pip install ${1} + pip install ${1} --no-deps } diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 662de93c16b..8853dc53b17 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -30,10 +30,11 @@ export TMPDIR="C:/tmp" mkdir -p "$TMPDIR" # Set bash path -export BAZEL_SH="C:/tools/msys64/usr/bin/bash" +export BAZEL_SH=${BAZEL_SH:-"C:/tools/msys64/usr/bin/bash"} # Set Python path for ./configure export PYTHON_BIN_PATH="C:/Program Files/Anaconda3/python" +export PYTHON_LIB_PATH="C:/Program Files/Anaconda3/lib/site-packages" # Set Python path for cc_configure.bzl export BAZEL_PYTHON="C:/Program Files/Anaconda3/python" @@ -54,4 +55,4 @@ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0/extras/CUPT export PATH="/c/tools/cuda/bin:$PATH" # Set the common build options on Windows -export BUILD_OPTS='--cpu=x64_windows_msvc --host_cpu=x64_windows_msvc --copt=/w --verbose_failures --experimental_ui' +export BUILD_OPTS='--copt=-w --host_copt=-w --verbose_failures --experimental_ui' diff --git a/tensorflow/tools/ci_build/windows/cpu/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/cpu/bazel/common_env.sh deleted file mode 100644 index 6e7e555065a..00000000000 --- a/tensorflow/tools/ci_build/windows/cpu/bazel/common_env.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# -# This script assumes the standard setup on tensorflow Jenkins windows machines. -# It is NOT guaranteed to work on any other machine. Use at your own risk! -# -# REQUIREMENTS: -# * All installed in standard locations: -# - JDK8, and JAVA_HOME set. -# - Microsoft Visual Studio 2015 Community Edition -# - Msys2 -# - Anaconda3 -# * Bazel windows executable copied as "bazel.exe" and included in PATH. - -# All commands shall pass, and all should be visible. -set -x -set -e - -# Use a temporary directory with a short name. -export TMPDIR="C:/tmp" -mkdir -p "$TMPDIR" - -# Set bash path -export BAZEL_SH="C:/tools/msys64/usr/bin/bash" - -# Set Python path for ./configure -export PYTHON_BIN_PATH="C:/Program Files/Anaconda3/python" - -# Set Python path for cc_configure.bzl -export BAZEL_PYTHON="C:/Program Files/Anaconda3/python" - -# Set Visual Studio path -export BAZEL_VS="C:/Program Files (x86)/Microsoft Visual Studio 14.0" - -# Add python into PATH, it's needed because gen_git_source.py uses -# '/usr/bin/env python' as a shebang -export PATH="/c/Program Files/Anaconda3:$PATH" diff --git a/tensorflow/tools/ci_build/windows/cpu/bazel/run_libtensorflow.bat b/tensorflow/tools/ci_build/windows/cpu/bazel/run_libtensorflow.bat new file mode 100644 index 00000000000..6a88b1865a4 --- /dev/null +++ b/tensorflow/tools/ci_build/windows/cpu/bazel/run_libtensorflow.bat @@ -0,0 +1 @@ +c:\tools\msys64\usr\bin\bash -l %cd%/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh %* diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat index 62e97f3f071..07ad70dd344 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat @@ -22,11 +22,13 @@ CALL "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" :: Turn echo back on, above script turns it off. ECHO ON -:: Some common variables to be shared between runs. -SET CMAKE_EXE="C:\Program Files\cmake\bin\cmake.exe" -SET SWIG_EXE="C:\swigwin-3.0.10\swig.exe" -SET PY_EXE="C:\Program Files\Anaconda3\python.exe" -SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib" +:: Set environment variables to be shared between runs. Do not override if they +:: are set already. 
+ +IF DEFINED CMAKE_EXE (ECHO CMAKE_EXE is set to %CMAKE_EXE%) ELSE (SET CMAKE_EXE="C:\Program Files\cmake\bin\cmake.exe") +IF DEFINED SWIG_EXE (ECHO SWIG_EXE is set to %SWIG_EXE%) ELSE (SET SWIG_EXE="C:\swigwin-3.0.10\swig.exe") +IF DEFINED PY_EXE (ECHO PY_EXE is set to %PY_EXE%) ELSE (SET PY_EXE="C:\Program Files\Anaconda3\python.exe") +IF DEFINED PY_LIB (ECHO PY_LIB is set to %PY_LIB%) ELSE (SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib") SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat index 9908762bca8..96fbadd1767 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat @@ -22,15 +22,13 @@ CD %BUILD_DIR% SET BUILD_CC_TESTS=OFF SET BUILD_PYTHON_TESTS=ON -SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe" +:: Set pip binary location. Do not override if it is set already. +IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe") :: Run the CMAKE build to build the pip package. CALL %REPO_ROOT%\tensorflow\tools\ci_build\windows\cpu\cmake\run_build.bat if %errorlevel% neq 0 exit /b %errorlevel% -:: Attempt to upgrade PIP to work around Anaconda issue #542. -%PIP_EXE% install --ignore-installed --upgrade pip setuptools -v -v - :: Since there are no wildcards in windows command prompt, use dark magic to get the wheel file name. DIR %REPO_ROOT%\%BUILD_DIR%\tf_python\dist\ /S /B > wheel_filename_file set /p WHEEL_FILENAME=&2; exit 1; } -clean_output_base - run_configure_for_cpu_build +clean_output_base + bazel build -c opt $BUILD_OPTS tensorflow/tools/pip_package:build_pip_package || exit $? # Create a python test directory to avoid package name conflict @@ -58,12 +58,10 @@ create_python_test_dir "${PY_TEST_DIR}" PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl) reinstall_tensorflow_pip ${PIP_NAME} -failing_cpu_py_tests=$(get_failing_cpu_py_tests ${PY_TEST_DIR}) - -passing_tests=$(bazel query "kind(py_test, //${PY_TEST_DIR}/tensorflow/python/...) - (${failing_cpu_py_tests})" | - # We need to strip \r so that the result could be store into a variable under MSYS - tr '\r' ' ') - # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow -bazel test -c opt $BUILD_OPTS -k $passing_tests --define=no_tensorflow_py_deps=true --test_output=errors +bazel test -c opt $BUILD_OPTS -k --test_output=errors \ + --define=no_tensorflow_py_deps=true --test_lang_filters=py \ + --test_tag_filters=-no_pip,-no_windows \ + --build_tag_filters=-no_pip,-no_windows --build_tests_only \ + //${PY_TEST_DIR}/tensorflow/python/... diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index f124012edcb..b4f9cc84762 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -22,12 +22,14 @@ CALL "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" :: Turn echo back on, above script turns it off. ECHO ON -:: Some common variables to be shared between runs. 
-SET CMAKE_EXE="C:\Program Files\cmake\bin\cmake.exe" -SET SWIG_EXE="C:\swigwin-3.0.10\swig.exe" -SET PY_EXE="C:\Program Files\Anaconda3\python.exe" -SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib" -SET CUDNN_HOME="c:\tools\cuda" +:: Set environment variables to be shared between runs. Do not override if they +:: are set already. + +IF DEFINED CMAKE_EXE (ECHO CMAKE_EXE is set to %CMAKE_EXE%) ELSE (SET CMAKE_EXE="C:\Program Files\cmake\bin\cmake.exe") +IF DEFINED SWIG_EXE (ECHO SWIG_EXE is set to %SWIG_EXE%) ELSE (SET SWIG_EXE="C:\swigwin-3.0.10\swig.exe") +IF DEFINED PY_EXE (ECHO PY_EXE is set to %PY_EXE%) ELSE (SET PY_EXE="C:\Program Files\Anaconda3\python.exe") +IF DEFINED PY_LIB (ECHO PY_LIB is set to %PY_LIB%) ELSE (SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib") +IF DEFINED CUDNN_HOME (ECHO CUDNN_HOME is set to %CUDNN_HOME%) ELSE (SET CUDNN_HOME="c:\tools\cuda") SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat index 9307ebb66ba..e774a6e9168 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat @@ -22,15 +22,12 @@ CD %BUILD_DIR% SET BUILD_CC_TESTS=OFF SET BUILD_PYTHON_TESTS=ON -SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe" +IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe") :: Run the CMAKE build to build the pip package. CALL %REPO_ROOT%\tensorflow\tools\ci_build\windows\gpu\cmake\run_build.bat if %errorlevel% neq 0 exit /b %errorlevel% -:: Attempt to upgrade PIP to work around Anaconda issue #542. -%PIP_EXE% install --ignore-installed --upgrade pip setuptools -v -v - :: Since there are no wildcards in windows command prompt, use dark magic to get the wheel file name. DIR %REPO_ROOT%\%BUILD_DIR%\tf_python\dist\ /S /B > wheel_filename_file set /p WHEEL_FILENAME=&2; exit 1; } -clean_output_base - run_configure_for_gpu_build +clean_output_base + bazel build -c opt --config=win-cuda $BUILD_OPTS tensorflow/tools/pip_package:build_pip_package || exit $? # Create a python test directory to avoid package name conflict @@ -58,13 +58,11 @@ create_python_test_dir "${PY_TEST_DIR}" PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl) reinstall_tensorflow_pip ${PIP_NAME} -failing_gpu_py_tests=$(get_failing_gpu_py_tests ${PY_TEST_DIR}) - -passing_tests=$(bazel query "kind(py_test, //${PY_TEST_DIR}/tensorflow/python/...) - (${failing_gpu_py_tests})" | - # We need to strip \r so that the result could be store into a variable under MSYS - tr '\r' ' ') - # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow -# GPU tests are very flaky when running concurently, so set local_test_jobs=5 -bazel test -c opt --config=win-cuda $BUILD_OPTS -k $passing_tests --define=no_tensorflow_py_deps=true --test_output=errors --local_test_jobs=5 +# GPU tests are very flaky when running concurrently, so set local_test_jobs=1 +bazel test -c opt --config=win-cuda $BUILD_OPTS -k --test_output=errors \ + --define=no_tensorflow_py_deps=true --test_lang_filters=py \ + --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu \ + --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu \ + --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/...
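The rewritten test invocation above drops the old "query for known-failing tests" step: tests that cannot work in this configuration are now excluded by their `no_pip`, `no_windows`, and `no_windows_gpu` tags. As a rough sketch of how negative tag filters select targets (illustrative only, not Bazel's actual implementation; the target names and tags are made up):

```python
def select_targets(targets, tag_filters):
    """Keep targets matching the positive tags and carrying no negative ('-') tag."""
    positive = {t for t in tag_filters if not t.startswith("-")}
    negative = {t[1:] for t in tag_filters if t.startswith("-")}
    selected = []
    for name, tags in targets.items():
        if positive and not positive & set(tags):
            continue  # must carry at least one positive tag, if any were given
        if negative & set(tags):
            continue  # excluded by a negative tag
        selected.append(name)
    return selected

tests = {
    "//tensorflow/python:session_test": [],
    "//tensorflow/python:pip_only_test": ["no_windows"],
    "//tensorflow/python:gpu_flaky_test": ["no_windows_gpu"],
}
print(select_targets(tests, ["-no_windows", "-no_windows_gpu"]))
# -> ['//tensorflow/python:session_test']
```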
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh new file mode 100755 index 00000000000..9ac3613f27e --- /dev/null +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Script to produce binary release of libtensorflow (C API, Java jars etc.). + +set -ex +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Setup environment for bazel builds +source "${SCRIPT_DIR}/bazel/common_env.sh" +source "${SCRIPT_DIR}/bazel/bazel_test_lib.sh" + +# Sanity check that this is being run from the root of the git repository. +cd ${SCRIPT_DIR}/../../../.. +if [ ! -e "WORKSPACE" ]; then + echo "Must run this from the root of the bazel workspace" + echo "Currently at ${PWD}, script is at ${SCRIPT_DIR}" + exit 1 +fi + +# Enable JNI support for Windows in Bazel. +# This can be removed once +# https://github.com/bazelbuild/bazel/pull/2599 +# has been merged and we switch to a bazel release containing it. +cp "${JAVA_HOME}/include/win32/jni_md.h" "./tensorflow/java/src/main/native/windows_jni_md.h" +sed -i -e "s|@bazel_tools//tools/jdk:jni_md_header-linux|windows_jni_md.h|" ./tensorflow/java/src/main/native/BUILD +#### END HACKS TO BE RESOLVED WITH NEW BAZEL VERSIONS #### + +export TF_BAZEL_TARGETS="//tensorflow:libtensorflow.so" +export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/tools/lib_package:clicenses_generate" +export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/java:libtensorflow_jni.so" +export TF_BAZEL_TARGETS="${TF_BAZEL_TARGETS} //tensorflow/tools/lib_package:jnilicenses_generate" + +clean_output_base +run_configure_for_cpu_build + +# build_libtensorflow_tarball in ../builds/libtensorflow.sh +# cannot be used on Windows since it relies on pkg_tar rules. +# So we do something special here +bazel build -c opt ${BUILD_OPTS} \ + tensorflow:libtensorflow.so \ + tensorflow/tools/lib_package:clicenses_generate \ + tensorflow/java:libtensorflow_jni.so \ + tensorflow/tools/lib_package:jnilicenses_generate + +# Revert the hacks above +git checkout ./tensorflow/tools/pip_package/BUILD +git checkout ./tensorflow/java/src/main/native/BUILD +rm -f ./tensorflow/java/src/main/native/windows_jni_md.h + +DIR=lib_package +rm -rf ${DIR} +mkdir -p ${DIR} + +# Zip up the .dll and the LICENSE for the JNI library. +cp bazel-bin/tensorflow/java/libtensorflow_jni.so ${DIR}/tensorflow_jni.dll +zip -j ${DIR}/libtensorflow_jni-cpu-windows-$(uname -m).zip \ + ${DIR}/tensorflow_jni.dll \ + bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/jni/LICENSE +rm -f ${DIR}/tensorflow_jni.dll + +# Zip up the .dll, LICENSE and include files for the C library. 
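The packaging steps above and in the next hunk stage the bazel outputs and bundle them with `zip -j`, which junks directory paths inside the archive. A minimal Python sketch of the JNI bundling step, assuming the bazel output files already exist at the paths named in the script:

```python
import os
import platform
import zipfile

DIR = "lib_package"
arch = platform.machine()  # stand-in for $(uname -m)

# Bundle the JNI DLL with its LICENSE, mirroring the `zip -j` call in the script.
jni_zip = os.path.join(DIR, "libtensorflow_jni-cpu-windows-%s.zip" % arch)
with zipfile.ZipFile(jni_zip, "w", zipfile.ZIP_DEFLATED) as zf:
    # arcname flattens the path, like zip's -j (junk paths) flag.
    zf.write(os.path.join(DIR, "tensorflow_jni.dll"),
             arcname="tensorflow_jni.dll")
    zf.write("bazel-genfiles/tensorflow/tools/lib_package/include/"
             "tensorflow/jni/LICENSE", arcname="LICENSE")
```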
+mkdir -p ${DIR}/include/tensorflow/c +mkdir -p ${DIR}/lib +cp bazel-bin/tensorflow/libtensorflow.so ${DIR}/lib/tensorflow.dll +cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c +cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c +cd ${DIR} +zip -j libtensorflow-cpu-windows-$(uname -m).zip \ + lib/tensorflow.dll \ + include/tensorflow/c/c_api.h \ + include/tensorflow/c/LICENSE +rm -rf lib include diff --git a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh new file mode 100755 index 00000000000..11064130713 --- /dev/null +++ b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_GCP=0 +export TF_NEED_HDFS=0 +export PYTHON_BIN_PATH=`which python3` + +export TF_NEED_CUDA=1 +export TF_ENABLE_XLA=1 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | ./configure + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --build_tests_only --test_output=errors --local_test_jobs=8 \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/compiler/... diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index 96ae9583d73..f92edd0dd88 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -9,6 +9,8 @@ package( default_visibility = ["//tensorflow:__subpackages__"], ) +load("//tensorflow:tensorflow.bzl", "py_test") + py_library( name = "public_api", srcs = ["public_api.py"], @@ -17,6 +19,7 @@ py_library( py_test( name = "public_api_test", + size = "small", srcs = ["public_api_test.py"], srcs_version = "PY2AND3", deps = [ @@ -33,6 +36,7 @@ py_library( py_test( name = "traverse_test", + size = "small", srcs = ["traverse_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py index 4c1ccebd616..e0acead9195 100644 --- a/tensorflow/tools/common/public_api.py +++ b/tensorflow/tools/common/public_api.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect +import re + +from tensorflow.python.util import tf_inspect class PublicAPIVisitor(object): @@ -34,24 +36,52 @@ class PublicAPIVisitor(object): visitor: A visitor to call for the public API. """ self._visitor = visitor + self._root_name = 'tf' - # Modules/classes we do not want to descend into if we hit them. 
Usually, - # sytem modules exposed through platforms for compatibility reasons. - # Each entry maps a module path to a name to ignore in traversal. - _do_not_descend_map = { - # TODO(drpng): This can be removed once sealed off. - '': ['platform', 'pywrap_tensorflow', 'user_ops', 'python'], + # Modules/classes we want to suppress entirely. + self._private_map = { + # Some implementations have this internal module that we shouldn't + # expose. + 'tf.flags': ['cpp_flags'], + } - # Exclude protos, they leak a lot. - 'core': ['protobuf'], + # Modules/classes we do not want to descend into if we hit them. Usually, + # system modules exposed through platforms for compatibility reasons. + # Each entry maps a module path to a name to ignore in traversal. + self._do_not_descend_map = { + 'tf': [ + 'core', + 'examples', + 'flags', # Don't add flags + # TODO(drpng): This can be removed once sealed off. + 'platform', + # TODO(drpng): This can be removed once sealed. + 'pywrap_tensorflow', + # TODO(drpng): This can be removed once sealed. + 'user_ops', + 'python', + 'tools', + 'tensorboard', + ], - # Some implementations have this internal module that we shouldn't expose. - 'flags': ['cpp_flags'], + ## Everything below here is legitimate. + # It'll stay, but it's not officially part of the API. + 'tf.app': ['flags'], + # Imported for compatibility between py2/3. + 'tf.test': ['mock'], + } - # Everything below here is legitimate. - 'app': ['flags'], # It'll stay, but it's not officially part of the API. - 'test': ['mock'], # Imported for compatibility between py2/3. - } + @property + def private_map(self): + """A map from parents to symbols that should not be included at all. + + This map can be edited, but it should not be edited once traversal has + begun. + + Returns: + The map marking symbols to not include. + """ + return self._private_map @property def do_not_descend_map(self): @@ -65,10 +95,17 @@ class PublicAPIVisitor(object): """ return self._do_not_descend_map - def _isprivate(self, name): + def set_root_name(self, root_name): + """Override the default root name of 'tf'.""" + self._root_name = root_name + + def _is_private(self, path, name): """Return whether a name is private.""" - # TODO(wicke): We have to almost certainly add more exceptions than init. - return name.startswith('_') and name not in ['__init__'] + # TODO(wicke): Find out what names to exclude. + return ((path in self._private_map and + name in self._private_map[path]) or + (name.startswith('_') and not re.match('__.*__$', name) or + name in ['__base__', '__class__'])) def _do_not_descend(self, path, name): """Safely queries if a specific fully qualified name should be excluded.""" @@ -79,18 +116,22 @@ class PublicAPIVisitor(object): """Visitor interface, see `traverse` for details.""" # Avoid long waits in cases of pretty unambiguous failure. - if inspect.ismodule(parent) and len(path.split('.')) > 10: - raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a ' - 'problem with an accidental public import.' % path) + if tf_inspect.ismodule(parent) and len(path.split('.')) > 10: + raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a ' + 'problem with an accidental public import.' % + (self._root_name, path)) + + # Includes self._root_name + full_path = '.'.join([self._root_name, path]) if path else self._root_name # Remove things that are not visible. 
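The visitor prunes `children` in place (the pruning loop follows in the next hunk), so `traverse` never descends into suppressed symbols. A small usage sketch under stated assumptions: `my_pkg` and its `util` submodule are hypothetical, and the visitor is passed to `traverse.traverse` as its visit callback:

```python
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse

import my_pkg  # hypothetical package to walk; any importable module works

names = []

def record(path, parent, children):
    # `children` has already had private symbols pruned by the visitor.
    names.extend(name if not path else "%s.%s" % (path, name)
                 for name, _ in children)

visitor = public_api.PublicAPIVisitor(record)
visitor.set_root_name("my_pkg")           # report paths under 'my_pkg', not 'tf'
visitor.private_map["my_pkg"] = ["util"]  # drop my_pkg.util from the walk entirely

traverse.traverse(my_pkg, visitor)
print(sorted(set(names)))
```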
for name, child in list(children): - if self._isprivate(name): + if self._is_private(full_path, name): children.remove((name, child)) self._visitor(path, parent, children) # Remove things that are visible, but which should not be descended into. for name, child in list(children): - if self._do_not_descend(path, name): + if self._do_not_descend(full_path, name): children.remove((name, child)) diff --git a/tensorflow/tools/common/traverse.py b/tensorflow/tools/common/traverse.py index 443838d9682..9607f80686d 100644 --- a/tensorflow/tools/common/traverse.py +++ b/tensorflow/tools/common/traverse.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import sys +from tensorflow.python.util import tf_inspect __all__ = ['traverse'] @@ -29,11 +29,11 @@ def _traverse_internal(root, visit, stack, path): """Internal helper for traverse.""" # Only traverse modules and classes - if not inspect.isclass(root) and not inspect.ismodule(root): + if not tf_inspect.isclass(root) and not tf_inspect.ismodule(root): return try: - children = inspect.getmembers(root) + children = tf_inspect.getmembers(root) except ImportError: # On some Python installations, some modules do not support enumerating # members (six in particular), leading to import errors. @@ -43,7 +43,8 @@ visit(path, root, children) for name, child in children: # Do not descend into built-in modules - if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names: + if tf_inspect.ismodule( + child) and child.__name__ in sys.builtin_module_names: continue # Break cycles @@ -72,8 +73,8 @@ def traverse(root, visit): never descends into built-in modules. `children`, a list of `(name, object)` pairs are determined by - `inspect.getmembers`. To avoid visiting parts of the tree, `children` can be - modified in place, using `del` or slice assignment. + `tf_inspect.getmembers`. To avoid visiting parts of the tree, `children` can + be modified in place, using `del` or slice assignment. Cycles (determined by reference equality, `is`) stop the traversal. A stack of objects is kept to find cycles. Objects forming cycles may appear in diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index 0f3de10a0ad..fb40cf0833f 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -10,12 +10,16 @@ load( py_binary( name = "tf_upgrade", - srcs = ["tf_upgrade.py"], + srcs = [ + "ast_edits.py", + "tf_upgrade.py", + ], srcs_version = "PY2AND3", ) py_test( name = "tf_upgrade_test", + size = "small", srcs = ["tf_upgrade_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/tools/compatibility/README.md b/tensorflow/tools/compatibility/README.md index d3bf2aa0324..aabc7b253d6 100644 --- a/tensorflow/tools/compatibility/README.md +++ b/tensorflow/tools/compatibility/README.md @@ -11,7 +11,10 @@ It will print a list of errors it finds that it can't fix. You can also run it on a directory tree: ``` -tf_upgrade.py --intree coolcode -outtree coolcode-upgraded +# just upgrade the .py files +tf_upgrade.py --intree coolcode --outtree coolcode-upgraded +# after upgrading the .py files, copy all the other files to the outtree +tf_upgrade.py --intree coolcode --outtree coolcode-upgraded --copyotherfiles True ``` In either case, it will also dump out a report e.g.
which will detail changes @@ -32,8 +35,8 @@ Renamed keyword argument from `squeeze_dims` to `axis` ## Caveats - Don't update parts of your code manually before running this script. In -particular, functions that have had reordered arguments like `tf.concat`, -`tf.split` will cause the script to incorrectly add keyword arguments that +particular, functions that have had reordered arguments like `tf.concat` +or `tf.split` will cause the script to incorrectly add keyword arguments that mismap arguments. - This script wouldn't actually reorder arguments. Instead, the script will add @@ -46,6 +49,12 @@ a tensor of bools. If the script detects this, it will report this to stdout `tf.reverse(a, [False, True, True])` you will need to manually change it to `tf.reverse(a, [1, 2])`. - - - +- There are some syntaxes that this script cannot handle, as it was designed +to use only standard Python packages. If the script fails +with "A necessary keyword argument failed to be inserted." or +"Failed to find keyword lexicographically. Fix manually.", you can try +[@machrisaa's fork of this script](https://github.com/machrisaa/tf0to1). +[@machrisaa](https://github.com/machrisaa) has used the +[RedBaron Python refactoring engine](https://redbaron.readthedocs.io/en/latest/) +which is able to localize syntactic elements more reliably than the built-in +`ast` module this script is based upon. diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py new file mode 100644 index 00000000000..e7e4c916921 --- /dev/null +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -0,0 +1,497 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts according to an API change specification.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast +import collections +import os +import shutil +import sys +import tempfile +import traceback + + +class APIChangeSpec(object): + """This class defines the transformations that need to happen. + + This class must provide the following fields: + + * `function_keyword_renames`: maps function names to a map of old -> new + argument names + * `function_renames`: maps function names to new function names + * `change_to_function`: a set of function names that have changed (for + notifications) + * `function_reorders`: maps functions whose argument order has changed to the + list of arguments in the new order + * `function_handle`: maps function names to custom handlers for the function + + For an example, see `TFAPIChangeSpec`. + """ + + +class _FileEditTuple(collections.namedtuple( + "_FileEditTuple", ["comment", "line", "start", "old", "new"])): + """Each edit that is recorded by a _FileEditRecorder. + + Fields: + comment: A description of the edit and why it was made.
+ line: The line number in the file where the edit occurs (1-indexed). + start: The column offset in the line where the edit begins (0-indexed). + old: text string to remove (this must match what was in file). + new: text string to add in place of `old`. + """ + + __slots__ = () + + +class _FileEditRecorder(object): + """Record changes that need to be done to the file.""" + + def __init__(self, filename): + # all edits are lists of chars + self._filename = filename + + self._line_to_edit = collections.defaultdict(list) + self._errors = [] + + def process(self, text): + """Process a list of strings, each corresponding to the recorded changes. + + Args: + text: A list of lines of text (assumed to contain newlines) + Returns: + A tuple of the modified text, a textual description of the edits, + and the list of errors. + Raises: + ValueError: if substitution source location does not contain the + expected text. + """ + + change_report = "" + + # Iterate over each line + for line, edits in self._line_to_edit.items(): + offset = 0 + # sort by column so that edits are processed in order, making the + # indexing adjustments cumulative for changes that change the string + # length + edits.sort(key=lambda x: x.start) + + # Extract each line to a list of characters, because mutable lists + # are editable, unlike immutable strings. + char_array = list(text[line - 1]) + + # Record a description of the change + change_report += "%r Line %d\n" % (self._filename, line) + change_report += "-" * 80 + "\n\n" + for e in edits: + change_report += "%s\n" % e.comment + change_report += "\n Old: %s" % (text[line - 1]) + + # Make underscore buffers for underlining where in the line the edit was + change_list = [" "] * len(text[line - 1]) + change_list_new = [" "] * len(text[line - 1]) + + # Iterate over each edit + for e in edits: + # Create effective start, end by accounting for change in length due + # to previous edits + start_eff = e.start + offset + end_eff = start_eff + len(e.old) + + # Make sure the edit is changing what it should be changing + old_actual = "".join(char_array[start_eff:end_eff]) + if old_actual != e.old: + raise ValueError("Expected text %r but got %r" % + ("".join(e.old), "".join(old_actual))) + # Make the edit + char_array[start_eff:end_eff] = list(e.new) + + # Create the underline highlighting of the before and after + change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) + change_list_new[start_eff:end_eff] = "~" * len(e.new) + + # Keep track of how to generate effective ranges + offset += len(e.new) - len(e.old) + + # Finish the report comment + change_report += " %s\n" % "".join(change_list) + text[line - 1] = "".join(char_array) + change_report += " New: %s" % (text[line - 1]) + change_report += " %s\n\n" % "".join(change_list_new) + return "".join(text), change_report, self._errors + + def add(self, comment, line, start, old, new, error=None): + """Add a new change that is needed. + + Args: + comment: A description of what was changed + line: Line number (1 indexed) + start: Column offset (0 indexed) + old: old text + new: new text + error: if set, this "edit" cannot be made automatically and the + message is recorded as an error + Returns: + None + """ + + self._line_to_edit[line].append( + _FileEditTuple(comment, line, start, old, new)) + if error: + self._errors.append("%s:%d: %s" % (self._filename, line, error)) + + +class _ASTCallVisitor(ast.NodeVisitor): + """AST Visitor that processes function calls. + + Updates function calls from old API version to new API version using a given + change spec.
+ """ + + def __init__(self, filename, lines, api_change_spec): + self._filename = filename + self._file_edit = _FileEditRecorder(filename) + self._lines = lines + self._api_change_spec = api_change_spec + + def process(self, lines): + return self._file_edit.process(lines) + + def generic_visit(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def _rename_functions(self, node, full_name): + function_renames = self._api_change_spec.function_renames + try: + new_name = function_renames[full_name] + self._file_edit.add("Renamed function %r to %r" % (full_name, + new_name), + node.lineno, node.col_offset, full_name, new_name) + except KeyError: + pass + + def _get_attribute_full_path(self, node): + """Traverse an attribute to generate a full name e.g. tf.foo.bar. + + Args: + node: A Node of type Attribute. + + Returns: + a '.'-delimited full-name or None if the tree was not a simple form. + i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". + """ + curr = node + items = [] + while not isinstance(curr, ast.Name): + if not isinstance(curr, ast.Attribute): + return None + items.append(curr.attr) + curr = curr.value + items.append(curr.id) + return ".".join(reversed(items)) + + def _find_true_position(self, node): + """Return correct line number and column offset for a given node. + + This is necessary mainly because ListComp's location reporting reports + the next token after the list comprehension list opening. + + Args: + node: Node for which we wish to know the lineno and col_offset + """ + import re + find_open = re.compile("^\s*(\\[).*$") + find_string_chars = re.compile("['\"]") + + if isinstance(node, ast.ListComp): + # Strangely, ast.ListComp returns the col_offset of the first token + # after the '[' token which appears to be a bug. Workaround by + # explicitly finding the real start of the list comprehension. + line = node.lineno + col = node.col_offset + # loop over lines + while 1: + # Reverse the text to and regular expression search for whitespace + text = self._lines[line-1] + reversed_preceding_text = text[:col][::-1] + # First find if a [ can be found with only whitespace between it and + # col. + m = find_open.match(reversed_preceding_text) + if m: + new_col_offset = col - m.start(1) - 1 + return line, new_col_offset + else: + if (reversed_preceding_text=="" or + reversed_preceding_text.isspace()): + line = line - 1 + prev_line = self._lines[line - 1] + # TODO(aselle): + # this is poor comment detection, but it is good enough for + # cases where the comment does not contain string literal starting/ + # ending characters. If ast gave us start and end locations of the + # ast nodes rather than just start, we could use string literal + # node ranges to filter out spurious #'s that appear in string + # literals. + comment_start = prev_line.find("#") + if comment_start == -1: + col = len(prev_line) -1 + elif find_string_chars.search(prev_line[comment_start:]) is None: + col = comment_start + else: + return None, None + else: + return None, None + # Most other nodes return proper locations (with notably does not), but + # it is not possible to use that in an argument. + return node.lineno, node.col_offset + + + def visit_Call(self, node): # pylint: disable=invalid-name + """Handle visiting a call node in the AST. + + Args: + node: Current Node + """ + + + # Find a simple attribute name path e.g. 
"tf.foo.bar" + full_name = self._get_attribute_full_path(node.func) + + # Make sure the func is marked as being part of a call + node.func.is_function_for_call = True + + if full_name: + # Call special handlers + function_handles = self._api_change_spec.function_handle + if full_name in function_handles: + function_handles[full_name](self._file_edit, node) + + # Examine any non-keyword argument and make it into a keyword argument + # if reordering required. + function_reorders = self._api_change_spec.function_reorders + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + + if full_name in function_reorders: + reordered = function_reorders[full_name] + for idx, arg in enumerate(node.args): + lineno, col_offset = self._find_true_position(arg) + if lineno is None or col_offset is None: + self._file_edit.add( + "Failed to add keyword %r to reordered function %r" + % (reordered[idx], full_name), arg.lineno, arg.col_offset, + "", "", + error="A necessary keyword argument failed to be inserted.") + else: + keyword_arg = reordered[idx] + if (full_name in function_keyword_renames and + keyword_arg in function_keyword_renames[full_name]): + keyword_arg = function_keyword_renames[full_name][keyword_arg] + self._file_edit.add("Added keyword %r to reordered function %r" + % (reordered[idx], full_name), lineno, + col_offset, "", keyword_arg + "=") + + # Examine each keyword argument and convert it to the final renamed form + renamed_keywords = ({} if full_name not in function_keyword_renames else + function_keyword_renames[full_name]) + for keyword in node.keywords: + argkey = keyword.arg + argval = keyword.value + + if argkey in renamed_keywords: + argval_lineno, argval_col_offset = self._find_true_position(argval) + if argval_lineno is not None and argval_col_offset is not None: + # TODO(aselle): We should scan backward to find the start of the + # keyword key. Unfortunately ast does not give you the location of + # keyword keys, so we are forced to infer it from the keyword arg + # value. + key_start = argval_col_offset - len(argkey) - 1 + key_end = key_start + len(argkey) + 1 + if (self._lines[argval_lineno - 1][key_start:key_end] == + argkey + "="): + self._file_edit.add("Renamed keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval_lineno, + argval_col_offset - len(argkey) - 1, + argkey + "=", renamed_keywords[argkey] + "=") + continue + self._file_edit.add( + "Failed to rename keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + "", "", + error="Failed to find keyword lexographically. Fix manually.") + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar]. 
+ + Args: + node: Node that is of type ast.Attribute + """ + full_name = self._get_attribute_full_path(node) + if full_name: + self._rename_functions(node, full_name) + if full_name in self._api_change_spec.change_to_function: + if not hasattr(node, "is_function_for_call"): + new_text = full_name + "()" + self._file_edit.add("Changed %r to %r"%(full_name, new_text), + node.lineno, node.col_offset, full_name, new_text) + + ast.NodeVisitor.generic_visit(self, node) + + +class ASTCodeUpgrader(object): + """Handles upgrading a set of Python files using a given API change spec.""" + + def __init__(self, api_change_spec): + if not isinstance(api_change_spec, APIChangeSpec): + raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % + type(api_change_spec)) + self._api_change_spec = api_change_spec + + def process_file(self, in_filename, out_filename): + """Process the given python file for incompatible changes. + + Args: + in_filename: filename to parse + out_filename: output file to write to + Returns: + A tuple representing number of files processed, log of actions, errors + """ + + # Write to a temporary file, just in case we are doing an in-place modify. + with open(in_filename, "r") as in_file, \ + tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + ret = self.process_opened_file( + in_filename, in_file, out_filename, temp_file) + + shutil.move(temp_file.name, out_filename) + return ret + + # Broad exceptions are required here because ast throws whatever it wants. + # pylint: disable=broad-except + def process_opened_file(self, in_filename, in_file, out_filename, out_file): + """Process the given python file for incompatible changes. + + This function is split out to facilitate StringIO testing from + tf_upgrade_test.py. + + Args: + in_filename: filename to parse + in_file: opened file (or StringIO) + out_filename: output file to write to + out_file: opened file (or StringIO) + Returns: + A tuple representing number of files processed, log of actions, errors + """ + process_errors = [] + text = "-" * 80 + "\n" + text += "Processing file %r\n outputting to %r\n" % (in_filename, + out_filename) + text += "-" * 80 + "\n\n" + + parsed_ast = None + lines = in_file.readlines() + try: + parsed_ast = ast.parse("".join(lines)) + except Exception: + text += "Failed to parse %r\n\n" % in_filename + text += traceback.format_exc() + if parsed_ast: + visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec) + visitor.visit(parsed_ast) + out_text, new_text, process_errors = visitor.process(lines) + text += new_text + if out_file: + out_file.write(out_text) + text += "\n" + return 1, text, process_errors + # pylint: enable=broad-except + + def process_tree(self, root_directory, output_root_directory, + copy_other_files): + """Processes upgrades on an entire tree of python files in place. + + Note that only Python files are processed. If you have custom code in + other languages, you will need to upgrade it manually. + + Args: + root_directory: Directory to walk and process. + output_root_directory: Directory to use as base. + copy_other_files: Copy files that are not touched by this converter. + + Returns: + A tuple of files processed, the report string for all files, and errors + + # make sure output directory doesn't exist + if output_root_directory and os.path.exists(output_root_directory): + print("Output directory %r must not already exist."
% ( + output_root_directory)) + sys.exit(1) + + # make sure output directory does not overlap with root_directory + norm_root = os.path.split(os.path.normpath(root_directory)) + norm_output = os.path.split(os.path.normpath(output_root_directory)) + if norm_root == norm_output: + print("Output directory %r same as input directory %r" % ( + root_directory, output_root_directory)) + sys.exit(1) + + # Collect list of files to process (we do this to correctly handle if the + # user puts the output directory in some sub directory of the input dir) + files_to_process = [] + files_to_copy = [] + for dir_name, _, file_list in os.walk(root_directory): + py_files = [f for f in file_list if f.endswith(".py")] + copy_files = [f for f in file_list if not f.endswith(".py")] + for filename in py_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join( + output_root_directory, os.path.relpath(fullpath, root_directory)) + files_to_process.append((fullpath, fullpath_output)) + if copy_other_files: + for filename in copy_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join( + output_root_directory, os.path.relpath(fullpath, root_directory)) + files_to_copy.append((fullpath, fullpath_output)) + + file_count = 0 + tree_errors = [] + report = "" + report += ("=" * 80) + "\n" + report += "Input tree: %r\n" % root_directory + report += ("=" * 80) + "\n" + + for input_path, output_path in files_to_process: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + file_count += 1 + _, l_report, l_errors = self.process_file(input_path, output_path) + tree_errors += l_errors + report += l_report + for input_path, output_path in files_to_copy: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + shutil.copy(input_path, output_path) + return file_count, report, tree_errors diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index bcff10f21d5..72fe4a48cdd 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -17,23 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import argparse -import ast -import collections -import os -import shutil -import sys -import tempfile -import traceback + +from tensorflow.tools.compatibility import ast_edits -class APIChangeSpec(object): +class TFAPIChangeSpec(ast_edits.APIChangeSpec): """List of maps that describe what changed in the API.""" def __init__(self): # Maps from a function name to a dictionary that describes how to # map from an old argument keyword to the new argument keyword. 
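`TFAPIChangeSpec` (its maps continue in the hunk below) just fills in the fields that `ast_edits` consumes. A minimal hypothetical spec, with invented `mylib` names, showing the expected shape of each field:

```python
from tensorflow.tools.compatibility import ast_edits

class ToyAPIChangeSpec(ast_edits.APIChangeSpec):
    """Hypothetical spec: renames mylib.mul and reorders mylib.concat."""

    def __init__(self):
        # old keyword -> new keyword, per function
        self.function_keyword_renames = {"mylib.concat": {"concat_dim": "axis"}}
        # old function name -> new function name
        self.function_renames = {"mylib.mul": "mylib.multiply"}
        # names that changed from bare attributes to function calls
        self.change_to_function = set()
        # functions whose positional args must be turned into keywords
        self.function_reorders = {"mylib.concat": ["concat_dim", "values"]}
        # custom per-function handlers
        self.function_handle = {}

upgrader = ast_edits.ASTCodeUpgrader(ToyAPIChangeSpec())
```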
self.function_keyword_renames = { + "tf.batch_matmul": { + "adj_x": "adjoint_a", + "adj_y": "adjoint_b", + }, "tf.count_nonzero": { "reduction_indices": "axis" }, @@ -140,6 +140,7 @@ class APIChangeSpec(object): "tf.batch_svd": "tf.svd", "tf.batch_fft": "tf.fft", "tf.batch_ifft": "tf.ifft", + "tf.batch_fft2d": "tf.fft2d", "tf.batch_ifft2d": "tf.ifft2d", "tf.batch_fft3d": "tf.fft3d", "tf.batch_ifft3d": "tf.ifft3d", @@ -148,6 +149,7 @@ class APIChangeSpec(object): "tf.batch_matmul": "tf.matmul", "tf.pack": "tf.stack", "tf.unpack": "tf.unstack", + "tf.op_scope": "tf.name_scope", } self.change_to_function = { @@ -168,11 +170,14 @@ class APIChangeSpec(object): "tf.nn.sparse_softmax_cross_entropy_with_logits": [ "logits", "labels", "name"], "tf.nn.sigmoid_cross_entropy_with_logits": [ - "logits", "labels", "name"] + "logits", "labels", "name"], + "tf.op_scope": ["values", "name", "default_name"], } # Specially handled functions. - self.function_handle = {"tf.reverse": self._reverse_handler} + self.function_handle = { + "tf.reverse": self._reverse_handler + } @staticmethod def _reverse_handler(file_edit_recorder, node): @@ -189,437 +194,6 @@ class APIChangeSpec(object): error="tf.reverse requires manual check.") -class FileEditTuple(collections.namedtuple( - "FileEditTuple", ["comment", "line", "start", "old", "new"])): - """Each edit that is recorded by a FileEditRecorder. - - Fields: - comment: A description of the edit and why it was made. - line: The line number in the file where the edit occurs (1-indexed). - start: The line number in the file where the edit occurs (0-indexed). - old: text string to remove (this must match what was in file). - new: text string to add in place of `old`. - """ - - __slots__ = () - - -class FileEditRecorder(object): - """Record changes that need to be done to the file.""" - - def __init__(self, filename): - # all edits are lists of chars - self._filename = filename - - self._line_to_edit = collections.defaultdict(list) - self._errors = [] - - def process(self, text): - """Process a list of strings, each corresponding to the recorded changes. - - Args: - text: A list of lines of text (assumed to contain newlines) - Returns: - A tuple of the modified text and a textual description of what is done. - Raises: - ValueError: if substitution source location does not have expected text. - """ - - change_report = "" - - # Iterate of each line - for line, edits in self._line_to_edit.items(): - offset = 0 - # sort by column so that edits are processed in order in order to make - # indexing adjustments cumulative for changes that change the string - # length - edits.sort(key=lambda x: x.start) - - # Extract each line to a list of characters, because mutable lists - # are editable, unlike immutable strings. 
- char_array = list(text[line - 1]) - - # Record a description of the change - change_report += "%r Line %d\n" % (self._filename, line) - change_report += "-" * 80 + "\n\n" - for e in edits: - change_report += "%s\n" % e.comment - change_report += "\n Old: %s" % (text[line - 1]) - - # Make underscore buffers for underlining where in the line the edit was - change_list = [" "] * len(text[line - 1]) - change_list_new = [" "] * len(text[line - 1]) - - # Iterate for each edit - for e in edits: - # Create effective start, end by accounting for change in length due - # to previous edits - start_eff = e.start + offset - end_eff = start_eff + len(e.old) - - # Make sure the edit is changing what it should be changing - old_actual = "".join(char_array[start_eff:end_eff]) - if old_actual != e.old: - raise ValueError("Expected text %r but got %r" % - ("".join(e.old), "".join(old_actual))) - # Make the edit - char_array[start_eff:end_eff] = list(e.new) - - # Create the underline highlighting of the before and after - change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) - change_list_new[start_eff:end_eff] = "~" * len(e.new) - - # Keep track of how to generate effective ranges - offset += len(e.new) - len(e.old) - - # Finish the report comment - change_report += " %s\n" % "".join(change_list) - text[line - 1] = "".join(char_array) - change_report += " New: %s" % (text[line - 1]) - change_report += " %s\n\n" % "".join(change_list_new) - return "".join(text), change_report, self._errors - - def add(self, comment, line, start, old, new, error=None): - """Add a new change that is needed. - - Args: - comment: A description of what was changed - line: Line number (1 indexed) - start: Column offset (0 indexed) - old: old text - new: new text - error: this "edit" is something that cannot be fixed automatically - Returns: - None - """ - - self._line_to_edit[line].append( - FileEditTuple(comment, line, start, old, new)) - if error: - self._errors.append("%s:%d: %s" % (self._filename, line, error)) - - -class TensorFlowCallVisitor(ast.NodeVisitor): - """AST Visitor that finds TensorFlow Function calls. - - Updates function calls from old API version to new API version. - """ - - def __init__(self, filename, lines): - self._filename = filename - self._file_edit = FileEditRecorder(filename) - self._lines = lines - self._api_change_spec = APIChangeSpec() - - def process(self, lines): - return self._file_edit.process(lines) - - def generic_visit(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def _rename_functions(self, node, full_name): - function_renames = self._api_change_spec.function_renames - try: - new_name = function_renames[full_name] - self._file_edit.add("Renamed function %r to %r" % (full_name, - new_name), - node.lineno, node.col_offset, full_name, new_name) - except KeyError: - pass - - def _get_attribute_full_path(self, node): - """Traverse an attribute to generate a full name e.g. tf.foo.bar. - - Args: - node: A Node of type Attribute. - - Returns: - a '.'-delimited full-name or None if the tree was not a simple form. - i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". - """ - curr = node - items = [] - while not isinstance(curr, ast.Name): - if not isinstance(curr, ast.Attribute): - return None - items.append(curr.attr) - curr = curr.value - items.append(curr.id) - return ".".join(reversed(items)) - - def _find_true_position(self, node): - """Return correct line number and column offset for a given node. 
- - This is necessary mainly because ListComp's location reporting reports - the next token after the list comprehension list opening. - - Args: - node: Node for which we wish to know the lineno and col_offset - """ - import re - find_open = re.compile("^\s*(\\[).*$") - find_string_chars = re.compile("['\"]") - - if isinstance(node, ast.ListComp): - # Strangely, ast.ListComp returns the col_offset of the first token - # after the '[' token which appears to be a bug. Workaround by - # explicitly finding the real start of the list comprehension. - line = node.lineno - col = node.col_offset - # loop over lines - while 1: - # Reverse the text to and regular expression search for whitespace - text = self._lines[line-1] - reversed_preceding_text = text[:col][::-1] - # First find if a [ can be found with only whitespace between it and - # col. - m = find_open.match(reversed_preceding_text) - if m: - new_col_offset = col - m.start(1) - 1 - return line, new_col_offset - else: - if (reversed_preceding_text=="" or - reversed_preceding_text.isspace()): - line = line - 1 - prev_line = self._lines[line - 1] - # TODO(aselle): - # this is poor comment detection, but it is good enough for - # cases where the comment does not contain string literal starting/ - # ending characters. If ast gave us start and end locations of the - # ast nodes rather than just start, we could use string literal - # node ranges to filter out spurious #'s that appear in string - # literals. - comment_start = prev_line.find("#") - if comment_start == -1: - col = len(prev_line) -1 - elif find_string_chars.search(prev_line[comment_start:]) is None: - col = comment_start - else: - return None, None - else: - return None, None - # Most other nodes return proper locations (with notably does not), but - # it is not possible to use that in an argument. - return node.lineno, node.col_offset - - - def visit_Call(self, node): # pylint: disable=invalid-name - """Handle visiting a call node in the AST. - - Args: - node: Current Node - """ - - - # Find a simple attribute name path e.g. "tf.foo.bar" - full_name = self._get_attribute_full_path(node.func) - - # Make sure the func is marked as being part of a call - node.func.is_function_for_call = True - - if full_name and full_name.startswith("tf."): - # Call special handlers - function_handles = self._api_change_spec.function_handle - if full_name in function_handles: - function_handles[full_name](self._file_edit, node) - - # Examine any non-keyword argument and make it into a keyword argument - # if reordering required. 
- function_reorders = self._api_change_spec.function_reorders - function_keyword_renames = ( - self._api_change_spec.function_keyword_renames) - - if full_name in function_reorders: - reordered = function_reorders[full_name] - for idx, arg in enumerate(node.args): - lineno, col_offset = self._find_true_position(arg) - if lineno is None or col_offset is None: - self._file_edit.add( - "Failed to add keyword %r to reordered function %r" - % (reordered[idx], full_name), arg.lineno, arg.col_offset, - "", "", - error="A necessary keyword argument failed to be inserted.") - else: - keyword_arg = reordered[idx] - if (full_name in function_keyword_renames and - keyword_arg in function_keyword_renames[full_name]): - keyword_arg = function_keyword_renames[full_name][keyword_arg] - self._file_edit.add("Added keyword %r to reordered function %r" - % (reordered[idx], full_name), lineno, - col_offset, "", keyword_arg + "=") - - # Examine each keyword argument and convert it to the final renamed form - renamed_keywords = ({} if full_name not in function_keyword_renames else - function_keyword_renames[full_name]) - for keyword in node.keywords: - argkey = keyword.arg - argval = keyword.value - - if argkey in renamed_keywords: - argval_lineno, argval_col_offset = self._find_true_position(argval) - if (argval_lineno is not None and argval_col_offset is not None): - # TODO(aselle): We should scan backward to find the start of the - # keyword key. Unfortunately ast does not give you the location of - # keyword keys, so we are forced to infer it from the keyword arg - # value. - key_start = argval_col_offset - len(argkey) - 1 - key_end = key_start + len(argkey) + 1 - if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=": - self._file_edit.add("Renamed keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval_lineno, - argval_col_offset - len(argkey) - 1, - argkey + "=", renamed_keywords[argkey] + "=") - continue - self._file_edit.add( - "Failed to rename keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval.lineno, - argval.col_offset - len(argkey) - 1, - "", "", - error="Failed to find keyword lexographically. Fix manually.") - - ast.NodeVisitor.generic_visit(self, node) - - def visit_Attribute(self, node): # pylint: disable=invalid-name - """Handle bare Attributes i.e. [tf.foo, tf.bar]. - - Args: - node: Node that is of type ast.Attribute - """ - full_name = self._get_attribute_full_path(node) - if full_name and full_name.startswith("tf."): - self._rename_functions(node, full_name) - if full_name in self._api_change_spec.change_to_function: - if not hasattr(node, "is_function_for_call"): - new_text = full_name + "()" - self._file_edit.add("Changed %r to %r"%(full_name, new_text), - node.lineno, node.col_offset, full_name, new_text) - - ast.NodeVisitor.generic_visit(self, node) - - -class TensorFlowCodeUpgrader(object): - """Class that handles upgrading a set of Python files to TensorFlow 1.0.""" - - def __init__(self): - pass - - def process_file(self, in_filename, out_filename): - """Process the given python file for incompatible changes. - - Args: - in_filename: filename to parse - out_filename: output file to write to - Returns: - A tuple representing number of files processed, log of actions, errors - """ - - # Write to a temporary file, just in case we are doing an implace modify. 
- with open(in_filename, "r") as in_file, \ - tempfile.NamedTemporaryFile("w", delete=False) as temp_file: - ret = self.process_opened_file( - in_filename, in_file, out_filename, temp_file) - - shutil.move(temp_file.name, out_filename) - return ret - - # Broad exceptions are required here because ast throws whatever it wants. - # pylint: disable=broad-except - def process_opened_file(self, in_filename, in_file, out_filename, out_file): - """Process the given python file for incompatible changes. - - This function is split out to facilitate StringIO testing from - tf_upgrade_test.py. - - Args: - in_filename: filename to parse - in_file: opened file (or StringIO) - out_filename: output file to write to - out_file: opened file (or StringIO) - Returns: - A tuple representing number of files processed, log of actions, errors - """ - process_errors = [] - text = "-" * 80 + "\n" - text += "Processing file %r\n outputting to %r\n" % (in_filename, - out_filename) - text += "-" * 80 + "\n\n" - - parsed_ast = None - lines = in_file.readlines() - try: - parsed_ast = ast.parse("".join(lines)) - except Exception: - text += "Failed to parse %r\n\n" % in_filename - text += traceback.format_exc() - if parsed_ast: - visitor = TensorFlowCallVisitor(in_filename, lines) - visitor.visit(parsed_ast) - out_text, new_text, process_errors = visitor.process(lines) - text += new_text - if out_file: - out_file.write(out_text) - text += "\n" - return 1, text, process_errors - # pylint: enable=broad-except - - def process_tree(self, root_directory, output_root_directory): - """Processes upgrades on an entire tree of python files in place. - - Note that only Python files. If you have custom code in other languages, - you will need to manually upgrade those. - - Args: - root_directory: Directory to walk and process. - output_root_directory: Directory to use as base - Returns: - A tuple of files processed, the report string ofr all files, and errors - """ - - # make sure output directory doesn't exist - if output_root_directory and os.path.exists(output_root_directory): - print("Output directory %r must not already exist." 
% ( - output_root_directory)) - sys.exit(1) - - # make sure output directory does not overlap with root_directory - norm_root = os.path.split(os.path.normpath(root_directory)) - norm_output = os.path.split(os.path.normpath(output_root_directory)) - if norm_root == norm_output: - print("Output directory %r same as input directory %r" % ( - root_directory, output_root_directory)) - sys.exit(1) - - # Collect list of files to process (we do this to correctly handle if the - # user puts the output directory in some sub directory of the input dir) - files_to_process = [] - for dir_name, _, file_list in os.walk(root_directory): - py_files = [f for f in file_list if f.endswith(".py")] - for filename in py_files: - fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join( - output_root_directory, os.path.relpath(fullpath, root_directory)) - files_to_process.append((fullpath, fullpath_output)) - - file_count = 0 - tree_errors = [] - report = "" - report += ("=" * 80) + "\n" - report += "Input tree: %r\n" % root_directory - report += ("=" * 80) + "\n" - - for input_path, output_path in files_to_process: - output_directory = os.path.dirname(output_path) - if not os.path.isdir(output_directory): - os.makedirs(output_directory) - file_count += 1 - _, l_report, l_errors = self.process_file(input_path, output_path) - tree_errors += l_errors - report += l_report - return file_count, report, tree_errors - - if __name__ == "__main__": parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, @@ -648,6 +222,13 @@ Simple usage: dest="output_tree", help="If converting a whole tree of files, the output " "directory (relative or absolute).") + parser.add_argument( + "--copyotherfiles", + dest="copy_other_files", + help=("If converting a whole tree of files, whether to " + "copy the other files."), + type=bool, + default=False) parser.add_argument( "--reportfile", dest="report_filename", @@ -657,7 +238,7 @@ Simple usage: default="report.txt") args = parser.parse_args() - upgrade = TensorFlowCodeUpgrader() + upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) report_text = None report_filename = args.report_filename files_processed = 0 @@ -667,7 +248,7 @@ Simple usage: files_processed = 1 elif args.input_tree: files_processed, report_text, errors = upgrade.process_tree( - args.input_tree, args.output_tree) + args.input_tree, args.output_tree, args.copy_other_files) else: parser.print_help() if report_text: diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index de4e3de73cd..ac838a2791f 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -22,6 +22,7 @@ import tempfile import six from tensorflow.python.framework import test_util from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import tf_upgrade @@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def _upgrade(self, old_file_text): in_file = six.StringIO(old_file_text) out_file = six.StringIO() - upgrader = tf_upgrade.TensorFlowCodeUpgrader() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) count, report, errors = ( upgrader.process_opened_file("test.py", in_file, "test_out.py", out_file)) @@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase): upgraded = "tf.multiply(a, b)\n" temp_file.write(original) temp_file.close() - upgrader = 
tf_upgrade.TensorFlowCodeUpgrader() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) upgrader.process_file(temp_file.name, temp_file.name) self.assertAllEqual(open(temp_file.name).read(), upgraded) os.unlink(temp_file.name) diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile index 65d7e1717e7..83bbeeca8a9 100644 --- a/tensorflow/tools/dist_test/Dockerfile +++ b/tensorflow/tools/dist_test/Dockerfile @@ -23,7 +23,7 @@ FROM ubuntu:16.04 MAINTAINER Shanqing Cai RUN apt-get update -RUN apt-get install -y --no-install-recommends \ +RUN apt-get install -y \ curl \ python \ python-numpy \ diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index f9f37ff0e11..7d7f92d246e 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -70,7 +70,7 @@ get_container_id_by_image_name() { # Get the id of a container by image name # Usage: get_docker_container_id_by_image_name - echo $(docker ps | grep $1 | awk '{print $1}') + docker ps | grep $1 | awk '{print $1}' } # Parse input arguments @@ -151,6 +151,8 @@ rm -rf "${BUILD_DIR}" # Run docker image for test. docker run ${DOCKER_IMG_NAME} \ /var/tf_dist_test/scripts/dist_mnist_test.sh \ - --ps_hosts "localhost:2000,localhost:2001" \ - --worker_hosts "localhost:3000,localhost:3001" \ + --ps_hosts $(seq -f "localhost:%g" -s "," \ + 2000 $((2000 + NUM_PARAMETER_SERVERS - 1))) \ + --worker_hosts $(seq -f "localhost:%g" -s "," \ + 3000 $((3000 + NUM_WORKERS - 1))) \ --num_gpus 0 ${SYNC_REPLICAS_FLAG} diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py index db56a687f6b..3a557814960 100644 --- a/tensorflow/tools/dist_test/python/census_widendeep.py +++ b/tensorflow/tools/dist_test/python/census_widendeep.py @@ -133,7 +133,7 @@ class CensusDataSource(object): columns: Columns to retrieve from the data files (A list of strings) label_column: Name of the label column categorical_columns: Names of the categorical columns (A list of strings) - continuous_columns: Names of the continuous columsn (A list of strings) + continuous_columns: Names of the continuous columns (A list of strings) """ # Retrieve data from disk (if available) or download from the web. diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index 7e68258b0a0..f7dbfea7fb0 100644 --- a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -16,9 +16,9 @@ """Distributed MNIST training and validation, with model replicas. A simple softmax model with one hidden layer is defined. The parameters -(weights and biases) are located on two parameter servers (ps), while the -ops are defined on a worker node. The TF sessions also run on the worker -node. +(weights and biases) are located on one parameter server (ps), while the ops +are executed on two worker nodes by default. The TF sessions also run on the +worker node. Multiple invocations of this script can be done in parallel, with different values for --task_index. 
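The updated test above shows the new entry point: build an `ASTCodeUpgrader` from a spec and feed it file-like objects. As a standalone sketch of that flow, using the `tf.mul` to `tf.multiply` rename the test exercises:

```python
import six
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade

# Upgrade a snippet in memory, exactly as the updated test does.
in_file = six.StringIO("tf.mul(a, b)\n")
out_file = six.StringIO()
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
count, report, errors = upgrader.process_opened_file(
    "test.py", in_file, "test_out.py", out_file)
print(out_file.getvalue())  # expected: tf.multiply(a, b)
```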
There should be exactly one invocation with --task_index, which will create a master session that carries out variable diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD new file mode 100644 index 00000000000..c329f0bbe87 --- /dev/null +++ b/tensorflow/tools/dist_test/scripts/BUILD @@ -0,0 +1,22 @@ +# Tools for running distributed benchmarks. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["k8s_tensorflow.py"]) + +py_library( + name = "k8s_tensorflow_lib", + srcs = ["k8s_tensorflow_lib.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "k8s_tensorflow_test", + size = "small", + srcs = ["k8s_tensorflow_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":k8s_tensorflow_lib", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py index 854c6b832a7..b325f030e36 100755 --- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py @@ -25,6 +25,8 @@ from __future__ import print_function import argparse import sys +import k8s_tensorflow_lib + # Note: It is intentional that we do not import tensorflow in this script. The # machine that launches a TensorFlow k8s cluster does not have to have the # Python package of TensorFlow installed on it. @@ -33,127 +35,12 @@ import sys DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server' DEFAULT_PORT = 2222 -# TODO(cais): Consider adding resource requests/limits to the pods. - -# Worker pods will mount host volume /shared, as a convenient way to create -# shared storage among workers during local tests. -WORKER_RC = ( - """apiVersion: v1 -kind: ReplicationController -metadata: - name: tf-worker{worker_id} -spec: - replicas: 1 - template: - metadata: - labels: - tf-worker: "{worker_id}" - spec: - containers: - - name: tf-worker{worker_id} - image: {docker_image} - args: - - --cluster_spec={cluster_spec} - - --job_name=worker - - --task_id={worker_id} - ports: - - containerPort: {port} - volumeMounts: - - name: shared - mountPath: /shared - volumes: - - name: shared - hostPath: - path: /shared -""") -WORKER_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: tf-worker{worker_id} - labels: - tf-worker: "{worker_id}" -spec: - ports: - - port: {port} - targetPort: {port} - selector: - tf-worker: "{worker_id}" -""") -WORKER_LB_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: tf-worker{worker_id} - labels: - tf-worker: "{worker_id}" -spec: - type: LoadBalancer - ports: - - port: {port} - selector: - tf-worker: "{worker_id}" -""") -PARAM_SERVER_RC = ( - """apiVersion: v1 -kind: ReplicationController -metadata: - name: tf-ps{param_server_id} -spec: - replicas: 1 - template: - metadata: - labels: - tf-ps: "{param_server_id}" - spec: - containers: - - name: tf-ps{param_server_id} - image: {docker_image} - args: - - --cluster_spec={cluster_spec} - - --job_name=ps - - --task_id={param_server_id} - ports: - - containerPort: {port} - volumeMounts: - - name: shared - mountPath: /shared - volumes: - - name: shared - hostPath: - path: /shared -""") -PARAM_SERVER_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: tf-ps{param_server_id} - labels: - tf-ps: "{param_server_id}" -spec: - ports: - - port: {port} - selector: - tf-ps: "{param_server_id}" -""") -PARAM_LB_SVC = ("""apiVersion: v1 -kind: Service -metadata: - name: tf-ps{param_server_id} - labels: - tf-ps: "{param_server_id}" -spec: - type: LoadBalancer - 
ports: - - port: {port} - selector: - tf-ps: "{param_server_id}" -""") - def main(): """Do arg parsing.""" parser = argparse.ArgumentParser() + parser.register( + 'type', 'bool', lambda v: v.lower() in ('true', 't', 'y', 'yes')) parser.add_argument('--num_workers', type=int, default=2, @@ -167,7 +54,7 @@ def main(): default=DEFAULT_PORT, help='GRPC server port (Default: %d)' % DEFAULT_PORT) parser.add_argument('--request_load_balancer', - type=bool, + type='bool', default=False, help='To request worker0 to be exposed on a public IP ' 'address via an external load balancer, enabling you to ' @@ -177,6 +64,16 @@ def main(): default=DEFAULT_DOCKER_IMAGE, help='Override default docker image for the TensorFlow ' 'GRPC server') + parser.add_argument('--name_prefix', + type=str, + default='tf', + help='Prefix for job names. Jobs will be named as ' + '<name_prefix>-worker|ps') + parser.add_argument('--use_shared_volume', + type='bool', + default=True, + help='Whether to mount /shared directory from host to ' + 'the pod') args = parser.parse_args() if args.num_workers <= 0: @@ -190,88 +87,17 @@ def main(): sys.exit(1) # Generate contents of yaml config - yaml_config = GenerateConfig(args.num_workers, - args.num_parameter_servers, - args.grpc_port, - args.request_load_balancer, - args.docker_image) + yaml_config = k8s_tensorflow_lib.GenerateConfig( + args.num_workers, + args.num_parameter_servers, + args.grpc_port, + args.request_load_balancer, + args.docker_image, + args.name_prefix, + env_vars=None, + use_shared_volume=args.use_shared_volume) print(yaml_config) # pylint: disable=superfluous-parens -def GenerateConfig(num_workers, - num_param_servers, - port, - request_load_balancer, - docker_image): - """Generate configuration strings.""" - config = '' - for worker in range(num_workers): - config += WORKER_RC.format( - port=port, - worker_id=worker, - docker_image=docker_image, - cluster_spec=WorkerClusterSpecString(num_workers, - num_param_servers, - port)) - config += '---\n' - if request_load_balancer: - config += WORKER_LB_SVC.format(port=port, - worker_id=worker) - else: - config += WORKER_SVC.format(port=port, - worker_id=worker) - config += '---\n' - - for param_server in range(num_param_servers): - config += PARAM_SERVER_RC.format( - port=port, - param_server_id=param_server, - docker_image=docker_image, - cluster_spec=ParamServerClusterSpecString(num_workers, - num_param_servers, - port)) - config += '---\n' - if request_load_balancer: - config += PARAM_LB_SVC.format(port=port, param_server_id=param_server) - else: - config += PARAM_SERVER_SVC.format(port=port, param_server_id=param_server) - config += '---\n' - - return config - - -def WorkerClusterSpecString(num_workers, - num_param_servers, - port): - """Generates worker cluster spec.""" - return ClusterSpecString(num_workers, num_param_servers, port) - - -def ParamServerClusterSpecString(num_workers, - num_param_servers, - port): - """Generates parameter server spec.""" - return ClusterSpecString(num_workers, num_param_servers, port) - - -def ClusterSpecString(num_workers, - num_param_servers, - port): - """Generates general cluster spec.""" - spec = 'worker|' - for worker in range(num_workers): - spec += 'tf-worker%d:%d' % (worker, port) - if worker != num_workers-1: - spec += ';' - - spec += ',ps|' - for param_server in range(num_param_servers): - spec += 'tf-ps%d:%d' % (param_server, port) - if param_server != num_param_servers-1: - spec += ';' - - return spec - - if __name__ == '__main__': main() diff --git 
a/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py new file mode 100644 index 00000000000..8adbe387ba3 --- /dev/null +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py @@ -0,0 +1,309 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generates YAML configuration files for distributed TensorFlow workers. + +The workers will be run in a Kubernetes (k8s) container cluster. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Note: It is intentional that we do not import tensorflow in this script. The +# machine that launches a TensorFlow k8s cluster does not have to have the +# Python package of TensorFlow installed on it. + +# TODO(cais): Consider adding resource requests/limits to the pods. + +# Worker pods will mount host volume /shared, as a convenient way to create +# shared storage among workers during local tests. +WORKER_RC = ( + """apiVersion: v1 +kind: ReplicationController +metadata: + name: {name_prefix}-worker{worker_id} +spec: + replicas: 1 + template: + metadata: + labels: + tf-worker: "{worker_id}" + name-prefix: "{name_prefix}" + job: "worker" + spec: + containers: + - name: tf-worker{worker_id} + image: {docker_image} + args: [{args}] + ports: + - containerPort: {port} + env: [{env_vars}] + volumeMounts: [{volume_mounts}] + volumes: [{volumes}] +""") +WORKER_SVC = ( + """apiVersion: v1 +kind: Service +metadata: + name: {name_prefix}-worker{worker_id} + labels: + tf-worker: "{worker_id}" +spec: + ports: + - port: {port} + targetPort: {port} + selector: + tf-worker: "{worker_id}" +""") +WORKER_LB_SVC = ( + """apiVersion: v1 +kind: Service +metadata: + name: {name_prefix}-worker{worker_id} + labels: + tf-worker: "{worker_id}" +spec: + type: LoadBalancer + ports: + - port: {port} + selector: + tf-worker: "{worker_id}" +""") +PARAM_SERVER_RC = ( + """apiVersion: v1 +kind: ReplicationController +metadata: + name: {name_prefix}-ps{param_server_id} +spec: + replicas: 1 + template: + metadata: + labels: + tf-ps: "{param_server_id}" + name-prefix: "{name_prefix}" + job: "ps" + spec: + containers: + - name: tf-ps{param_server_id} + image: {docker_image} + args: [{args}] + ports: + - containerPort: {port} + env: [{env_vars}] + volumeMounts: [{volume_mounts}] + volumes: [{volumes}] +""") +PARAM_SERVER_SVC = ( + """apiVersion: v1 +kind: Service +metadata: + name: {name_prefix}-ps{param_server_id} + labels: + tf-ps: "{param_server_id}" +spec: + ports: + - port: {port} + selector: + tf-ps: "{param_server_id}" +""") +PARAM_LB_SVC = ("""apiVersion: v1 +kind: Service +metadata: + name: {name_prefix}-ps{param_server_id} + labels: + tf-ps: "{param_server_id}" +spec: + type: LoadBalancer + ports: + - port: {port} + selector: + tf-ps: "{param_server_id}" +""") +VOLUME_MOUNTS = '{name: shared, mountPath: 
/shared}' +VOLUMES = '{name: shared, hostPath: {path: /shared}}' +_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}' +_ARG_TEMPLATE = '"--%s=%s"' + + +def GenerateConfig(num_workers, + num_param_servers, + port, + request_load_balancer, + docker_image, + name_prefix, + env_vars=None, + use_shared_volume=True, + use_cluster_spec=True): + """Generate configuration strings. + + Args: + num_workers: number of worker jobs. + num_param_servers: number of ps server jobs. + port: GRPC server port. + request_load_balancer: request worker0 to be exposed on a public IP + address via an external load balancer. + docker_image: docker image to use. + name_prefix: name to prepend to pod job names. + env_vars: dictionary of environment variables to set. + use_shared_volume: whether to add hostPath to /shared directory + to the kubernetes config. + use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs. + If false, pass --worker_hosts and --ps_hosts to worker and ps jobs. + + Returns: + Kubernetes yaml config. + """ + if env_vars is None: + env_vars = {} + env_str = ', '.join([_ENV_VAR_TEMPLATE % (name, value) + for name, value in env_vars.items()]) + config = '' + common_args = GetCommonArgs( + num_workers, num_param_servers, port, name_prefix, use_cluster_spec) + for worker in range(num_workers): + worker_args = { + 'job_name': 'worker', + 'task_id': worker + } + worker_args.update(common_args) + arg_str = ', '.join([_ARG_TEMPLATE % (name, value) + for name, value in worker_args.items()]) + config += WORKER_RC.format( + port=port, + worker_id=worker, + docker_image=docker_image, + name_prefix=name_prefix, + volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', + volumes=VOLUMES if use_shared_volume else '', + args=arg_str, + env_vars=env_str) + config += '---\n' + if request_load_balancer: + config += WORKER_LB_SVC.format(port=port, + worker_id=worker, + name_prefix=name_prefix) + else: + config += WORKER_SVC.format(port=port, + worker_id=worker, + name_prefix=name_prefix) + config += '---\n' + + for param_server in range(num_param_servers): + ps_args = { + 'job_name': 'ps', + 'task_id': param_server + } + ps_args.update(common_args) + arg_str = ', '.join([_ARG_TEMPLATE % (name, value) + for name, value in ps_args.items()]) + config += PARAM_SERVER_RC.format( + port=port, + param_server_id=param_server, + docker_image=docker_image, + name_prefix=name_prefix, + volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', + volumes=VOLUMES if use_shared_volume else '', + args=arg_str, + env_vars=env_str) + config += '---\n' + if request_load_balancer: + config += PARAM_LB_SVC.format( + port=port, param_server_id=param_server, name_prefix=name_prefix) + else: + config += PARAM_SERVER_SVC.format( + port=port, param_server_id=param_server, name_prefix=name_prefix) + config += '---\n' + + return config + + +def WorkerClusterSpecString(num_workers, + num_param_servers, + port, + name_prefix): + """Generates worker cluster spec.""" + return ClusterSpecString(num_workers, num_param_servers, port, name_prefix) + + +def ParamServerClusterSpecString(num_workers, + num_param_servers, + port, + name_prefix): + """Generates parameter server spec.""" + return ClusterSpecString(num_workers, num_param_servers, port, + name_prefix) + + +def ClusterSpecString(num_workers, + num_param_servers, + port, + name_prefix): + """Generates general cluster spec.""" + spec = 'worker|' + for worker in range(num_workers): + spec += '%s-worker%d:%d' % (name_prefix, worker, port) + if worker != num_workers-1: + spec += ';' + + 
spec += ',ps|' + for param_server in range(num_param_servers): + spec += '%s-ps%d:%d' % (name_prefix, param_server, port) + if param_server != num_param_servers-1: + spec += ';' + + return spec + + +def GetCommonArgs(num_workers, + num_param_servers, + port, + name_prefix, + use_cluster_spec): + """Get arguments common to both worker and ps jobs. + + Args: + num_workers: number of workers. + num_param_servers: number of ps servers. + port: worker and ps port number. + name_prefix: prefix to prepend to job names. + use_cluster_spec: if true, pass --cluster_spec argument. + If false, pass --worker_hosts and --ps_hosts arguments. + + Returns: + A dictionary of argument names mapping to argument values. + """ + common_args = {} + if use_cluster_spec: + common_args['cluster_spec'] = WorkerClusterSpecString( + num_workers, + num_param_servers, + port, + name_prefix) + else: + common_args['worker_hosts'] = WorkerHosts(num_workers, port, name_prefix) + common_args['ps_hosts'] = PsHosts(num_param_servers, port, name_prefix) + return common_args + + +def WorkerHosts(num_workers, port, name_prefix): + """Returns a comma-separated list of worker host:port strings.""" + worker_hosts = ['%s-worker%d:%d' % (name_prefix, i, port) + for i in range(num_workers)] + return ','.join(worker_hosts) + + +def PsHosts(num_ps, port, name_prefix): + """Returns a comma-separated list of ps host:port strings.""" + ps_hosts = ['%s-ps%d:%d' % (name_prefix, i, port) + for i in range(num_ps)] + return ','.join(ps_hosts) diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py new file mode 100644 index 00000000000..7d9b3f83f51 --- /dev/null +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py @@ -0,0 +1,132 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import googletest +from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib + + +class K8sTensorflowTest(googletest.TestCase): + + def testGenerateConfig_LoadBalancer(self): + # Use loadbalancer + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=True, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False) + self.assertTrue('LoadBalancer' in config) + + # Don't use loadbalancer + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=False, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False) + self.assertFalse('LoadBalancer' in config) + + def testGenerateConfig_SharedVolume(self): + # Use shared directory + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=False, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=True) + self.assertTrue('/shared' in config) + + # Don't use shared directory + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=False, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False) + self.assertFalse('/shared' in config) + + def testEnvVar(self): + # Use loadbalancer + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=True, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False, + env_vars={'test1': 'test1_value', 'test2': 'test2_value'}) + self.assertTrue('{name: "test1", value: "test1_value"}' in config) + self.assertTrue('{name: "test2", value: "test2_value"}' in config) + + def testClusterSpec(self): + # Use cluster_spec + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=True, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False, + use_cluster_spec=True) + self.assertFalse('worker_hosts' in config) + self.assertFalse('ps_hosts' in config) + self.assertTrue( + '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config) + + # Don't use cluster_spec + config = k8s_tensorflow_lib.GenerateConfig( + num_workers=1, + num_param_servers=1, + port=5000, + request_load_balancer=True, + docker_image='test_image', + name_prefix='abc', + use_shared_volume=False, + use_cluster_spec=False) + self.assertFalse('cluster_spec' in config) + self.assertTrue('"--worker_hosts=abc-worker0:5000"' in config) + self.assertTrue('"--ps_hosts=abc-ps0:5000"' in config) + + def testWorkerHosts(self): + self.assertEquals( + 'test_prefix-worker0:1234', + k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix')) + self.assertEquals( + 'test_prefix-worker0:1234,test_prefix-worker1:1234', + k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix')) + + def testPsHosts(self): + self.assertEquals( + 'test_prefix-ps0:1234,test_prefix-ps1:1234', + k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix')) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD index 
25efc83716e..865af8dd7b2 100644 --- a/tensorflow/tools/dist_test/server/BUILD +++ b/tensorflow/tools/dist_test/server/BUILD @@ -7,7 +7,9 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -py_library( +load("//tensorflow:tensorflow.bzl", "py_test") + +py_binary( name = "grpc_tensorflow_server", srcs = [ "grpc_tensorflow_server.py", diff --git a/tensorflow/tools/dist_test/server/Dockerfile b/tensorflow/tools/dist_test/server/Dockerfile index 4b13b814e39..fabc8a7105e 100644 --- a/tensorflow/tools/dist_test/server/Dockerfile +++ b/tensorflow/tools/dist_test/server/Dockerfile @@ -17,7 +17,7 @@ # # To build the image, use ../build_server.sh -FROM ubuntu:14.04 +FROM ubuntu:16.04 MAINTAINER Shanqing Cai diff --git a/tensorflow/tools/dist_test/server/Dockerfile.test b/tensorflow/tools/dist_test/server/Dockerfile.test index e2feb2227bb..908af8af9bb 100644 --- a/tensorflow/tools/dist_test/server/Dockerfile.test +++ b/tensorflow/tools/dist_test/server/Dockerfile.test @@ -17,7 +17,7 @@ # # To build the image, use ../build_server.sh --test -FROM ubuntu:14.04 +FROM ubuntu:16.04 MAINTAINER Shanqing Cai @@ -52,13 +52,13 @@ ADD . /var/tf-k8s # Download MNIST data for tests RUN mkdir -p /tmp/mnist-data RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \ - http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz + https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \ - http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz + https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \ - http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz + https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \ - http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz + https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz # Download Census data for Wide & Deep test RUN mkdir -p /tmp/census-data diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py old mode 100755 new mode 100644 index 2d774577b6d..bd6700a0b1f --- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py +++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py @@ -36,6 +36,7 @@ from __future__ import print_function import argparse import sys +from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.platform import app from tensorflow.python.training import server_lib @@ -103,8 +104,11 @@ def main(unused_args): raise ValueError("Invalid task_id: %d" % FLAGS.task_id) server_def.task_index = FLAGS.task_id + config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions( + per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)) + # Create GRPC Server instance - server = server_lib.Server(server_def) + server = server_lib.Server(server_def, config=config) # join() is blocking, unlike start() server.join() @@ -137,6 +141,11 @@ if __name__ == "__main__": default=0, help="Task index, e.g., 0" ) + parser.add_argument( + "--gpu_memory_fraction", + type=float, + default=1.0, + help="Fraction of GPU memory allocated",) parser.add_argument( "--verbose", type="bool", @@ -145,5 +154,6 @@ if __name__ == "__main__": default=False, help="Verbose mode" ) + FLAGS, unparsed = parser.parse_known_args() app.run(main=main, 
argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile index 4f00696be59..5b3f1f936a4 100644 --- a/tensorflow/tools/docker/Dockerfile +++ b/tensorflow/tools/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:14.04 +FROM ubuntu:16.04 MAINTAINER Craig Citro @@ -66,4 +66,4 @@ EXPOSE 8888 WORKDIR "/notebooks" -CMD ["/run_jupyter.sh"] +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 8cd6ee6f331..38a67f80aae 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -1,4 +1,4 @@ -FROM ubuntu:14.04 +FROM ubuntu:16.04 MAINTAINER Craig Citro @@ -17,6 +17,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ unzip \ zip \ zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -46,34 +48,21 @@ COPY run_jupyter.sh / # Set up Bazel. -# We need to add a custom PPA to pick up JDK8, since trusty doesn't -# have an openjdk8 backport. openjdk-r is maintained by a reliable contributor: -# Matthias Klose (https://launchpad.net/~doko). It will do until -# we either update the base image beyond 14.04 or openjdk-8 is -# finally backported to trusty; see e.g. -# https://bugs.launchpad.net/trusty-backports/+bug/1368094 -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - apt-get update && \ - apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - # Running bazel inside a `docker build` command causes trouble, cf: # https://github.com/bazelbuild/bazel/issues/134 # The easiest solution is to set up a bazelrc file forcing --batch. -RUN echo "startup --batch" >>/root/.bazelrc +RUN echo "startup --batch" >>/etc/bazel.bazelrc # Similarly, we need to workaround sandboxing issues: # https://github.com/bazelbuild/bazel/issues/418 RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ - >>/root/.bazelrc -ENV BAZELRC /root/.bazelrc + >>/etc/bazel.bazelrc # Install the most recent bazel release. 
-ENV BAZEL_VERSION 0.4.2 +ENV BAZEL_VERSION 0.5.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ - curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ - curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ chmod +x bazel-*.sh && \ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ cd / && \ @@ -83,7 +72,7 @@ RUN mkdir /bazel && \ RUN git clone https://github.com/tensorflow/tensorflow.git && \ cd tensorflow && \ - git checkout r1.0 + git checkout r1.2 WORKDIR /tensorflow # TODO(craigcitro): Don't install the pip package, since it makes it @@ -93,7 +82,8 @@ WORKDIR /tensorflow ENV CI_BUILD_PYTHON python RUN tensorflow/tools/ci_build/builds/configured CPU \ - bazel build -c opt tensorflow/tools/pip_package:build_pip_package && \ + bazel build -c opt --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + tensorflow/tools/pip_package:build_pip_package && \ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \ pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \ rm -rf /tmp/pip && \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index a3ccf919179..d0a038a9db6 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -1,4 +1,4 @@ -FROM nvidia/cuda:8.0-cudnn5-devel +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 MAINTAINER Craig Citro @@ -17,6 +17,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ unzip \ zip \ zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -46,34 +48,21 @@ COPY run_jupyter.sh / # Set up Bazel. -# We need to add a custom PPA to pick up JDK8, since trusty doesn't -# have an openjdk8 backport. openjdk-r is maintained by a reliable contributor: -# Matthias Klose (https://launchpad.net/~doko). It will do until -# we either update the base image beyond 14.04 or openjdk-8 is -# finally backported to trusty; see e.g. -# https://bugs.launchpad.net/trusty-backports/+bug/1368094 -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - apt-get update && \ - apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - # Running bazel inside a `docker build` command causes trouble, cf: # https://github.com/bazelbuild/bazel/issues/134 # The easiest solution is to set up a bazelrc file forcing --batch. -RUN echo "startup --batch" >>/root/.bazelrc +RUN echo "startup --batch" >>/etc/bazel.bazelrc # Similarly, we need to workaround sandboxing issues: # https://github.com/bazelbuild/bazel/issues/418 RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ - >>/root/.bazelrc -ENV BAZELRC /root/.bazelrc + >>/etc/bazel.bazelrc # Install the most recent bazel release. 
-ENV BAZEL_VERSION 0.4.2 +ENV BAZEL_VERSION 0.5.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ - curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ - curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ chmod +x bazel-*.sh && \ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ cd / && \ @@ -83,7 +72,7 @@ RUN mkdir /bazel && \ RUN git clone https://github.com/tensorflow/tensorflow.git && \ cd tensorflow && \ - git checkout r1.0 + git checkout r1.2 WORKDIR /tensorflow # Configure the build for our CUDA configuration. @@ -93,7 +82,8 @@ ENV TF_NEED_CUDA 1 ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1 RUN tensorflow/tools/ci_build/builds/configured GPU \ - bazel build -c opt --config=cuda tensorflow/tools/pip_package:build_pip_package && \ + bazel build -c opt --config=cuda --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + tensorflow/tools/pip_package:build_pip_package && \ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \ pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \ rm -rf /tmp/pip && \ diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index 77113c1d828..3ba1e963f92 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -1,4 +1,4 @@ -FROM nvidia/cuda:8.0-cudnn5-devel +FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 MAINTAINER Craig Citro @@ -69,4 +69,4 @@ EXPOSE 8888 WORKDIR "/notebooks" -CMD ["/run_jupyter.sh"] +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index 77fd8fc0d4f..6d5a9bdc4ce 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -10,16 +10,16 @@ General installation instructions are quick links here: * [OSX](https://www.docker.com/products/docker#/mac) -* [ubuntu](https://docs.docker.com/engine/installation/linux/ubuntulinux/) +* [Ubuntu](https://docs.docker.com/engine/installation/linux/ubuntulinux/) ## Which containers exist? -We currently maintain three Docker container images: +We currently maintain two Docker container images: * `gcr.io/tensorflow/tensorflow` - TensorFlow with all dependencies - CPU only! * `gcr.io/tensorflow/tensorflow:latest-gpu` - TensorFlow with all dependencies - and support for Nvidia Cuda + and support for NVidia CUDA Note: We also publish the same containers into [Docker Hub](https://hub.docker.com/r/tensorflow/tensorflow/tags/). @@ -37,9 +37,9 @@ For GPU support install NVidia drivers (ideally latest) and $ nvidia-docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow:latest-gpu -Note: If you would have a problem running nvidia-docker you may try the old way -we have used. But it is not recomended. If you find a bug in nvidia-docker report -it there please and try using the nvidia-docker as described above. 
+Note: If you have trouble running nvidia-docker, you may try the old method +we have used, but it is not recommended. If you find a bug in nvidia-docker, please report +it there and try using nvidia-docker as described above. $ export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | xargs -I{} echo '-v {}:{}') $ export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}') @@ -49,11 +49,35 @@ it there please and try using the nvidia-docker as described above. ## More containers See all available [tags](https://hub.docker.com/r/tensorflow/tensorflow/tags/) -for additional containers like release candidates or nighlty builds. +for additional containers, such as release candidates or nightly builds. ## Rebuilding the containers -Just pick the dockerfile corresponding to the container you want to build, and run +Building TensorFlow Docker containers should be done through the +[parameterized_docker_build.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/parameterized_docker_build.sh) +script. The raw Dockerfiles should not be used directly as they contain strings +to be replaced by the script during the build. - $ docker build --pull -t $USER/tensorflow-suffix -f Dockerfile.suffix . +To use the script, specify the container type (`CPU` vs. `GPU`), the desired +Python version (`PYTHON2` vs. `PYTHON3`), and whether the developer Docker image +is to be built (`NO` vs. `YES`). In addition, you need to specify the central +location from which the TensorFlow pip package will be downloaded. + +For example, to build a CPU-only non-developer Docker image for Python 2, using +TensorFlow's nightly pip package: + +``` bash +export TF_DOCKER_BUILD_IS_DEVEL=NO +export TF_DOCKER_BUILD_TYPE=CPU +export TF_DOCKER_BUILD_PYTHON_VERSION=PYTHON2 + +export NIGHTLY_VERSION="1.head" +export TF_DOCKER_BUILD_CENTRAL_PIP=$(echo ${TF_DOCKER_BUILD_PYTHON_VERSION} | sed s^PYTHON2^http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=${TF_DOCKER_BUILD_PYTHON_VERSION},label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp27-cp27mu-manylinux1_x86_64.whl^ | sed s^PYTHON3^http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp35-cp35m-manylinux1_x86_64.whl^) + +tensorflow/tools/docker/parameterized_docker_build.sh +``` + +If successful, the image will be tagged as `${USER}/tensorflow:latest` by default. + +Rebuilding GPU images requires [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). 
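The README example above generalizes to GPU images. A minimal sketch, assuming the same script and flag conventions and leaving the wheel URL as an explicit placeholder (it varies by nightly job and is not a verified link):

```bash
# Hypothetical GPU build using the parameterized_docker_build.sh flags
# documented above. TF_DOCKER_BUILD_CENTRAL_PIP must point at a GPU wheel;
# the value below is a placeholder, not a real URL.
export TF_DOCKER_BUILD_IS_DEVEL=NO
export TF_DOCKER_BUILD_TYPE=GPU
export TF_DOCKER_BUILD_PYTHON_VERSION=PYTHON2
export TF_DOCKER_BUILD_CENTRAL_PIP="<url-of-a-tensorflow-gpu-wheel>"

tensorflow/tools/docker/parameterized_docker_build.sh
```

The resulting image would then be started with `nvidia-docker run`, as shown earlier in this README.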
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py index 6b1ebc3ee0a..747beb8251e 100644 --- a/tensorflow/tools/docker/jupyter_notebook_config.py +++ b/tensorflow/tools/docker/jupyter_notebook_config.py @@ -22,5 +22,10 @@ c.MultiKernelManager.default_kernel_name = 'python2' # sets a password if PASSWORD is set in the environment if 'PASSWORD' in os.environ: - c.NotebookApp.password = passwd(os.environ['PASSWORD']) + password = os.environ['PASSWORD'] + if password: + c.NotebookApp.password = passwd(password) + else: + c.NotebookApp.password = '' + c.NotebookApp.token = '' del os.environ['PASSWORD'] diff --git a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb index c0b9f10b2eb..0633b03259a 100644 --- a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb +++ b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb @@ -72,7 +72,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -136,7 +135,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -181,7 +179,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -278,7 +275,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -343,7 +339,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -425,7 +420,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -512,7 +506,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ @@ -604,7 +597,6 @@ }, "outputs": [ { - "metadata": {}, "name": "stdout", "output_type": "stream", "text": [ diff --git a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb index b35b14df1fd..c9f2b1ab9ef 100644 --- a/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb +++ b/tensorflow/tools/docker/notebooks/3_mnist_from_scratch.ipynb @@ -134,7 +134,7 @@ "import os\n", "from six.moves.urllib.request import urlretrieve\n", "\n", - "SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'\n", + "SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'\n", "WORK_DIRECTORY = \"/tmp/mnist-data\"\n", "\n", "def maybe_download(filename):\n", diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh index 35c12184700..ea88d8165f4 100755 --- a/tensorflow/tools/docker/parameterized_docker_build.sh +++ b/tensorflow/tools/docker/parameterized_docker_build.sh @@ -64,7 +64,7 @@ # # TF_DOCKER_BUILD_OPTIONS # (Optional) -# Specifices the desired build options. Defaults to OPT. +# Specifies the desired build options. Defaults to OPT. # Script directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -233,13 +233,16 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Modify python/pip version if necessary. 
if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" && \ - echo "Modified Dockerfile for python version "\ -"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" || \ - die "FAILED to modify ${DOCKERFILE} for python3" + sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile for python version "\ +"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi else DOCKERFILE="${TMP_DIR}/Dockerfile" @@ -250,14 +253,17 @@ else # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ + if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" && \ - echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" || \ - die "FAILED to modify ${DOCKERFILE} for python3" + sed -i -e 's^# RUN ln -s /usr/bin/python3 /usr/bin/python#^RUN ln -s /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi @@ -266,7 +272,7 @@ fi IMG="${USER}/tensorflow:${FINAL_TAG}" echo "Building docker image with image name and tag: ${IMG}" -"${DOCKER_BINARY}" build --no-cache -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}" +"${DOCKER_BINARY}" build --no-cache --pull -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}" if [[ $? == "0" ]]; then echo "${DOCKER_BINARY} build of ${IMG} succeeded" else @@ -277,7 +283,7 @@ fi # Make sure that there is no other containers of the same image running # TODO(cais): Move to an earlier place. -if [[ ! -z $("${DOCKER_BINARY}" ps | grep "${IMG}") ]]; then +if "${DOCKER_BINARY}" ps | grep -q "${IMG}"; then die "ERROR: It appears that there are docker containers of the image "\ "${IMG} running. Please stop them before proceeding" fi @@ -310,16 +316,22 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # on the running docker container echo "" echo "Performing basic sanity checks on the running container..." 
- wget -qO- "http://127.0.0.1:${CONTAINER_PORT}/tree" &> /dev/null && \ - echo " PASS: wget tree" || \ - mark_check_failed " FAIL: wget tree" + if wget -qO- "http://127.0.0.1:${CONTAINER_PORT}/tree" &> /dev/null + then + echo " PASS: wget tree" + else + mark_check_failed " FAIL: wget tree" + fi for NB in ${TMP_DIR}/notebooks/*.ipynb; do NB_BASENAME=$(basename "${NB}") NB_URL="http://127.0.0.1:${CONTAINER_PORT}/notebooks/${NB_BASENAME}" - wget -qO- "${NB_URL}" -o "${TMP_DIR}/${NB_BASENAME}" &> /dev/null && \ - echo " PASS: wget ${NB_URL}" || \ - mark_check_failed " FAIL: wget ${NB_URL}" + if wget -qO- "${NB_URL}" -o "${TMP_DIR}/${NB_BASENAME}" &> /dev/null + then + echo " PASS: wget ${NB_URL}" + else + mark_check_failed " FAIL: wget ${NB_URL}" + fi done fi diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index f321354eb58..8e27b133c2f 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -9,12 +9,7 @@ package( default_visibility = ["//tensorflow:__subpackages__"], ) -py_binary( - name = "gen_cc_md", - srcs = ["gen_cc_md.py"], - srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], -) +load("//tensorflow:tensorflow.bzl", "py_test") py_library( name = "doc_generator_visitor", @@ -39,87 +34,110 @@ py_test( py_library( name = "parser", - srcs = [ - "parser.py", - ], + srcs = ["parser.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], ) py_test( name = "parser_test", size = "small", - srcs = [ - "parser_test.py", - ], + srcs = ["parser_test.py"], srcs_version = "PY2AND3", + tags = ["manual"], deps = [ ":parser", "//tensorflow/python:platform_test", ], ) +py_library( + name = "pretty_docs", + srcs = ["pretty_docs.py"], + srcs_version = "PY2AND3", +) + +py_binary( + name = "generate_lib", + srcs = ["generate_lib.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":doc_generator_visitor", + ":parser", + ":pretty_docs", + ":py_guide_parser", + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + "//tensorflow/tools/common:public_api", + "//tensorflow/tools/common:traverse", + ], +) + +py_test( + name = "generate_lib_test", + size = "small", + srcs = ["generate_lib_test.py"], + srcs_version = "PY2AND3", + tags = ["manual"], + deps = [ + ":generate_lib", + ":parser", + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform_test", + "//tensorflow/python/debug:debug_py", + ], +) + py_binary( name = "generate", srcs = ["generate.py"], srcs_version = "PY2AND3", deps = [ + ":generate_lib", "//tensorflow:tensorflow_py", - "//tensorflow/tools/common:public_api", - "//tensorflow/tools/common:traverse", - "//tensorflow/tools/docs:doc_generator_visitor", - "//tensorflow/tools/docs:parser", + "//tensorflow/python/debug:debug_py", ], ) py_test( - name = "generate_test", + name = "build_docs_test", size = "small", - srcs = [ - "generate_test.py", - ], + srcs = ["build_docs_test.py"], + data = ["//tensorflow:docs_src"], srcs_version = "PY2AND3", tags = ["manual"], deps = [ - ":generate", - "//tensorflow/python:platform_test", + ":generate_lib", + "//tensorflow:tensorflow_py", + "//tensorflow/python/debug:debug_py", ], ) py_binary( - name = "make_py_guides", - srcs = ["make_py_guides.py"], + name = "generate_1_0", + srcs = ["generate_1_0.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/tools/docs:generate", - "//tensorflow/tools/docs:parser", + ":generate_lib", + "//tensorflow:tensorflow_py", + "//tensorflow/python/debug:debug_py", ], ) -filegroup( - name = "doxy_config", - srcs = 
["tf-doxy_for_md-config"], +py_library( + name = "py_guide_parser", + srcs = ["py_guide_parser.py"], + srcs_version = "PY2AND3", ) -sh_binary( - name = "gen_docs", - srcs = ["gen_docs.sh"], - data = [ - ":doxy_config", - ":gen_cc_md", - "//tensorflow/python:gen_docs_combined", - ], -) - -sh_test( - name = "gen_docs_test", +py_test( + name = "py_guide_parser_test", size = "small", - srcs = [ - "gen_docs_test.sh", - ], - data = [ - ":gen_docs", - "//tensorflow/core:all_files", - "//tensorflow/python:all_files", + srcs = ["py_guide_parser_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":py_guide_parser", + "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/tools/docs/build_docs_test.py b/tensorflow/tools/docs/build_docs_test.py new file mode 100644 index 00000000000..d28dd93b9a8 --- /dev/null +++ b/tensorflow/tools/docs/build_docs_test.py @@ -0,0 +1,51 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run the python doc generator and fail if there are any broken links.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf +from tensorflow.python import debug as tf_debug +from tensorflow.python.platform import googletest +from tensorflow.python.platform import resource_loader +from tensorflow.tools.docs import generate_lib + + +class Flags(object): + resource_root = resource_loader.get_root_dir_with_all_resources() + src_dir = os.path.join(resource_root, 'third_party/tensorflow/docs_src') + base_dir = os.path.join(resource_root, 'third_party/tensorflow/') + output_dir = googletest.GetTempDir() + + +class BuildDocsTest(googletest.TestCase): + + def testBuildDocs(self): + doc_generator = generate_lib.DocGenerator() + + doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) + + status = doc_generator.build(Flags()) + + if status: + self.fail('Found %s Errors!' % status) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/docs/doc_generator_visitor.py b/tensorflow/tools/docs/doc_generator_visitor.py index d4ff33a0726..259a4694fdc 100644 --- a/tensorflow/tools/docs/doc_generator_visitor.py +++ b/tensorflow/tools/docs/doc_generator_visitor.py @@ -18,17 +18,36 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - import six +from tensorflow.python.util import tf_inspect + class DocGeneratorVisitor(object): """A visitor that generates docs for a python object when __call__ed.""" - def __init__(self): + def __init__(self, root_name=''): + """Make a visitor. + + As this visitor is starting its traversal at a module or class, it will not + be told the name of that object during traversal. `root_name` is the name it + should use for that object, effectively prefixing all names with + "root_name.". 
+ + Args: + root_name: The name of the root module/class. + """ + self.set_root_name(root_name) self._index = {} self._tree = {} + self._reverse_index = None + self._duplicates = None + self._duplicate_of = None + + def set_root_name(self, root_name): + """Sets the root name for subsequent __call__s.""" + self._root_name = root_name or '' + self._prefix = (root_name + '.') if root_name else '' @property def index(self): @@ -53,6 +72,56 @@ class DocGeneratorVisitor(object): """ return self._tree + @property + def reverse_index(self): + """A map from `id(object)` to the preferred fully qualified name. + + This map only contains non-primitive objects (no numbers or strings) present + in `index` (for primitive objects, `id()` doesn't quite do the right thing). + + It is computed when it, `duplicate_of`, or `duplicates` are first accessed. + + Returns: + The `id(object)` to full name map. + """ + self._maybe_find_duplicates() + return self._reverse_index + + @property + def duplicate_of(self): + """A map from duplicate full names to a preferred fully qualified name. + + This map only contains names that are not themselves a preferred name. + + It is computed when it, `reverse_index`, or `duplicates` are first accessed. + + Returns: + The map from duplicate name to preferred name. + """ + self._maybe_find_duplicates() + return self._duplicate_of + + @property + def duplicates(self): + """A map from preferred full names to a list of all names for this symbol. + + This function returns a map from preferred (master) name for a symbol to a + lexicographically sorted list of all aliases for that name (incl. the master + name). Symbols without duplicate names do not appear in this map. + + It is computed when it, `reverse_index`, or `duplicate_of` are first + accessed. + + Returns: + The map from master name to list of all duplicate names. + """ + self._maybe_find_duplicates() + return self._duplicates + + def _add_prefix(self, name): + """Adds the root name to a name.""" + return self._prefix + name if name else self._root_name + def __call__(self, parent_name, parent, children): """Visitor interface, see `tensorflow/tools/common:traverse` for details. @@ -64,42 +133,48 @@ class DocGeneratorVisitor(object): parent_name: The fully qualified name of a symbol found during traversal. parent: The Python object referenced by `parent_name`. children: A list of `(name, py_object)` pairs enumerating, in alphabetical - order, the children (as determined by `inspect.getmembers`) of `parent`. - `name` is the local name of `py_object` in `parent`. + order, the children (as determined by `tf_inspect.getmembers`) of + `parent`. `name` is the local name of `py_object` in `parent`. Raises: RuntimeError: If this visitor is called with a `parent` that is not a class or module. 
""" + parent_name = self._add_prefix(parent_name) self._index[parent_name] = parent self._tree[parent_name] = [] - if inspect.ismodule(parent): - print('module %s: %r' % (parent_name, parent)) - elif inspect.isclass(parent): - print('class %s: %r' % (parent_name, parent)) - else: - raise RuntimeError('Unexpected type in visitor -- %s: %r' % - (parent_name, parent)) + if not (tf_inspect.ismodule(parent) or tf_inspect.isclass(parent)): + raise RuntimeError('Unexpected type in visitor -- %s: %r' % (parent_name, + parent)) + + for i, (name, child) in enumerate(list(children)): + # Don't document __metaclass__ + if name in ['__metaclass__']: + del children[i] + continue - for name, child in children: full_name = '.'.join([parent_name, name]) if parent_name else name self._index[full_name] = child self._tree[parent_name].append(name) - def find_duplicates(self): + def _maybe_find_duplicates(self): """Compute data structures containing information about duplicates. Find duplicates in `index` and decide on one to be the "master" name. - Returns a map `duplicate_of` from aliases to their master name (the master - name itself has no entry in this map), and a map `duplicates` from master - names to a lexicographically sorted list of all aliases for that name (incl. - the master name). + Computes a reverse_index mapping each object id to its master name. - Returns: - A tuple `(duplicate_of, duplicates)` as described above. + Also computes a map `duplicate_of` from aliases to their master name (the + master name itself has no entry in this map), and a map `duplicates` from + master names to a lexicographically sorted list of all aliases for that name + (incl. the master name). + + All these are computed and set as fields if they haven't already. """ + if self._reverse_index is not None: + return + # Maps the id of a symbol to its fully qualified name. For symbols that have # several aliases, this map contains the first one found. # We use id(py_object) to get a hashable value for py_object. Note all @@ -110,15 +185,13 @@ class DocGeneratorVisitor(object): # maps the first name found to a list of all duplicate names. raw_duplicates = {} for full_name, py_object in six.iteritems(self._index): - # We cannot use the duplicate mechanism for constants, since e.g., + # We cannot use the duplicate mechanism for some constants, since e.g., # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants # have no usable docstring and won't be documented automatically. - if (inspect.ismodule(py_object) or - inspect.isclass(py_object) or - inspect.isfunction(py_object) or - inspect.isroutine(py_object) or - inspect.ismethod(py_object) or - isinstance(py_object, property)): + if (py_object is not None and + not isinstance(py_object, six.integer_types + six.string_types + + (six.binary_type, six.text_type, float, complex, bool)) + and py_object is not ()): object_id = id(py_object) if object_id in reverse_index: master_name = reverse_index[object_id] @@ -148,4 +221,9 @@ class DocGeneratorVisitor(object): if name != master_name: duplicate_of[name] = master_name - return duplicate_of, duplicates + # Set the reverse index to the canonical name. 
+ reverse_index[id(self._index[master_name])] = master_name + + self._duplicate_of = duplicate_of + self._duplicates = duplicates + self._reverse_index = reverse_index diff --git a/tensorflow/tools/docs/doc_generator_visitor_test.py b/tensorflow/tools/docs/doc_generator_visitor_test.py index bbaa1c6474c..cf5be45f40e 100644 --- a/tensorflow/tools/docs/doc_generator_visitor_test.py +++ b/tensorflow/tools/docs/doc_generator_visitor_test.py @@ -75,8 +75,6 @@ class DocGeneratorVisitorTest(googletest.TestCase): [('index', doc_generator_visitor.DocGeneratorVisitor.index), ('index2', doc_generator_visitor.DocGeneratorVisitor.index)]) - duplicate_of, duplicates = visitor.find_duplicates() - # The shorter path should be master, or if equal, the lexicographically # first will be. self.assertEqual( @@ -91,7 +89,7 @@ class DocGeneratorVisitorTest(googletest.TestCase): 'DocGeneratorVisitor2.index', 'DocGeneratorVisitor2.index2' ]), - }, duplicates) + }, visitor.duplicates) self.assertEqual({ 'submodule.DocGeneratorVisitor': 'DocGeneratorVisitor2', 'submodule.DocGeneratorVisitor.index': 'DocGeneratorVisitor2.index', @@ -100,8 +98,12 @@ class DocGeneratorVisitorTest(googletest.TestCase): 'submodule2.DocGeneratorVisitor.index': 'DocGeneratorVisitor2.index', 'submodule2.DocGeneratorVisitor.index2': 'DocGeneratorVisitor2.index', 'DocGeneratorVisitor2.index2': 'DocGeneratorVisitor2.index' - }, duplicate_of) - + }, visitor.duplicate_of) + self.assertEqual({ + id(doc_generator_visitor.DocGeneratorVisitor): 'DocGeneratorVisitor2', + id(doc_generator_visitor.DocGeneratorVisitor.index): + 'DocGeneratorVisitor2.index', + }, visitor.reverse_index) if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/gen_cc_md.py b/tensorflow/tools/docs/gen_cc_md.py deleted file mode 100644 index 931df3230b4..00000000000 --- a/tensorflow/tools/docs/gen_cc_md.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convert Doxygen .xml files to MarkDown (.md files).""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import re - -from BeautifulSoup import BeautifulStoneSoup -import tensorflow as tf - -ANCHOR_RE = re.compile(r'\W+') - -PAGE_TEMPLATE = '''# `{0} {1}` - -{2} - -###Member Details - -{3}''' - -INDEX_TEMPLATE = '''# TensorFlow C++ Session API reference documentation - -TensorFlow's public C++ API includes only the API for executing graphs, as of -version 0.5. To control the execution of a graph from C++: - -1. Build the computation graph using the [Python API](../python/). -1. Use [`tf.train.write_graph()`](../python/train.md#write_graph) to -write the graph to a file. -1. Load the graph using the C++ Session API. 
For example: - - ```c++ - // Reads a model graph definition from disk, and creates a session object you - // can use to run it. - Status LoadGraph(string graph_file_name, Session** session) { - GraphDef graph_def; - TF_RETURN_IF_ERROR( - ReadBinaryProto(Env::Default(), graph_file_name, &graph_def)); - TF_RETURN_IF_ERROR(NewSession(SessionOptions(), session)); - TF_RETURN_IF_ERROR((*session)->Create(graph_def)); - return Status::OK(); - } -``` - -1. Run the graph with a call to `session->Run()` - -## Env - -@@Env -@@RandomAccessFile -@@WritableFile -@@EnvWrapper - -## Session - -@@Session -@@SessionOptions - -## Status - -@@Status -@@Status::State - -## Tensor - -@@Tensor -@@TensorShape -@@TensorShapeDim -@@TensorShapeUtils -@@PartialTensorShape -@@PartialTensorShapeUtils - -## Thread - -@@Thread -@@ThreadOptions -''' - -FLAGS = None - - -def member_definition(member_elt): - def_text = '' - - def_elt = member_elt.find('definition') - if def_elt: - def_text = def_elt.text - - return def_text - - -def member_sig(member_elt): - def_text = member_definition(member_elt) - - argstring_text = '' - argstring = member_elt.find('argsstring') - if argstring: - argstring_text = argstring.text - - sig = def_text + argstring_text - return sig - - -def anchorize(name): - return ANCHOR_RE.sub('_', name) - - -def element_text(member_elt, elt_name): - """Extract all `para` text from (`elt_name` in) `member_elt`.""" - text = [] - if elt_name: - elt = member_elt.find(elt_name) - else: - elt = member_elt - - if elt: - paras = elt.findAll('para') - for p in paras: - text.append(p.getText(separator=u' ').strip()) - return '\n\n'.join(text) - - -def full_member_entry(member_elt): - """Generate the description of `member_elt` for "Member Details".""" - anchor = '{#' + anchorize(member_definition(member_elt)) + '}' - full_entry = '#### `%s` %s' % (member_sig(member_elt), anchor) - - complete_descr = element_text(member_elt, 'briefdescription') + '\n\n' - complete_descr += element_text(member_elt, 'detaileddescription') - - if complete_descr: - full_entry += '\n\n' + complete_descr - - return full_entry - - -def brief_member_entry(member_elt): - """Generate the description of `member_elt` for the "Member Summary".""" - brief_item = '' - brief_descr = element_text(member_elt, 'briefdescription') - if brief_descr: - brief_item = '\n * ' + brief_descr - sig = member_sig(member_elt) - memdef = member_definition(member_elt) - linkified_sig = '[`{0}`](#{1})'.format(sig, anchorize(memdef)) - - return '* ' + linkified_sig + brief_item - - -def all_briefs(members): - briefs = [brief_member_entry(member_elt) for member_elt in members] - return '\n'.join(briefs) - - -def all_fulls(members): - fulls = [full_member_entry(member_elt) for member_elt in members] - return '\n\n'.join(fulls) - - -def page_overview(class_elt): - """Returns the contents of the .md file for `class_elt`.""" - overview_brief = '' - overview_details = '' - - briefs = class_elt.findAll('briefdescription', recursive=False) - if briefs: - overview_brief = element_text(briefs[0], None) - - details = class_elt.findAll('detaileddescription', recursive=False) - if details: - overview_details = element_text(details[0], None) - - return overview_brief + '\n\n' + overview_details - - -def page_with_name(pages, name): - def match(n): - for i in xrange(len(pages)): - if pages[i].get_name() == n: - return i - return None - return match(name) or match('tensorflow::' + name) - - -def get_all_indexed_pages(): - all_pages = set() - lines = INDEX_TEMPLATE.split('\n') - for i 
in range(len(lines)): - if lines[i].startswith('@@'): - name = lines[i][2:] - all_pages.add(name) - return all_pages - - -def index_page(pages): - """Create the index page linking to `pages` using INDEX_TEMPLATE.""" - pages = pages[:] - lines = INDEX_TEMPLATE.split('\n') - all_md_files = [] - for i in range(len(lines)): - if lines[i].startswith('@@'): - name = lines[i][2:] - page_index = page_with_name(pages, name) - if page_index is None: - raise ValueError('Missing page with name: ' + name) - lines[i] = '* [{0}]({1})'.format( - pages[page_index].get_name(), pages[page_index].get_md_filename()) - all_md_files.append(pages[page_index].get_md_filename()) - pages.pop(page_index) - - return '\n'.join(lines) - - -def page_in_name_list(page, names): - for name in names: - if page.get_name() == name or page.get_name() == 'tensorflow::' + name: - return True - return False - - -class Page(object): - """Holds the MarkDown converted contents of a .xml page.""" - - def __init__(self, xml_path, deftype): - self.type = deftype - xml_file = open(xml_path) - xml = xml_file.read() - xml = xml.replace('', '`').replace('', '`') - # TODO(josh11b): Should not use HTML entities inside ```...```. - soup = BeautifulStoneSoup( - xml, convertEntities=BeautifulStoneSoup.HTML_ENTITIES) - self.name = soup.find('compoundname').text - print('Making page with name ' + self.name + ' (from ' + xml_path + ')') - members = soup('memberdef', prot='public') - fulls = all_fulls(members) - self.overview = page_overview(soup.find('compounddef')) - self.page_text = PAGE_TEMPLATE.format( - self.type, self.name, self.overview, fulls) - - def get_text(self): - return self.page_text - - def get_name(self): - return self.name - - def get_short_name(self): - parse = self.get_name().split('::') - return parse[len(parse)-1] - - def get_type(self): - return self.type - - def get_md_filename(self): - capitalized_type = self.get_type()[0].upper() + self.get_type()[1:] - return capitalized_type + anchorize(self.get_short_name()) + '.md' - - -def main(unused_argv): - print('Converting in ' + FLAGS.src_dir) - pages = [] - all_pages = get_all_indexed_pages() - xml_files = os.listdir(FLAGS.src_dir) - for fname in xml_files: - if len(fname) < 6: continue - newpage = None - if fname[0:5] == 'class': - newpage = Page(os.path.join(FLAGS.src_dir, fname), 'class') - elif fname[0:6] == 'struct': - newpage = Page(os.path.join(FLAGS.src_dir, fname), 'struct') - if newpage is not None and page_in_name_list(newpage, all_pages): - pages.append(newpage) - md_filename = newpage.get_md_filename() - print('Writing ' + md_filename) - md_file = open(os.path.join(FLAGS.out_dir, md_filename), 'w') - print(newpage.get_text(), file=md_file) - - index_text = index_page(pages) - index_md_file = open(os.path.join(FLAGS.out_dir, 'index.md'), 'w') - print(index_text, file=index_md_file) - return 0 - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--src_dir', - type=str, - default=None, - help='Directory containing the doxygen output.' - ) - parser.add_argument( - '--out_dir', - type=str, - default=None, - help='Directory to which docs should be written.' - ) - FLAGS = parser.parse_args() - - tf.app.run() diff --git a/tensorflow/tools/docs/gen_docs.sh b/tensorflow/tools/docs/gen_docs.sh deleted file mode 100755 index 4f529270ab4..00000000000 --- a/tensorflow/tools/docs/gen_docs.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# This script needs to be run from the tensorflow/tools/docs directory -# Pass -a to also rebuild C++ docs. This requires doxygen. - -set -e - -DOC_DIR="g3doc/api_docs" -DOXYGEN_BIN=${DOXYGEN:-doxygen} -DOXYGEN_CONFIG="tools/docs/tf-doxy_for_md-config" -# The TMP_DIR is set inside DOXYGEN_CONFIG and cannot be changed independently -TMP_DIR=/tmp/tensorflow-docs/xml - -if [ ! -f gen_docs.sh ]; then - echo "This script must be run from inside the tensorflow/tools/docs directory." - exit 1 -fi - -# go to the tensorflow/ directory -pushd ../.. -BASE=$(pwd) - -# Make Python docs -bazel run -- //tensorflow/python:gen_docs_combined \ - --out_dir=$BASE/$DOC_DIR/python - -# Check if we should build c++ docs (if -a is given) -if [ x$1 == x-a ]; then - mkdir -p $TMP_DIR - $DOXYGEN_BIN "$BASE/$DOXYGEN_CONFIG" - bazel run -- //tensorflow/tools/docs:gen_cc_md \ - --out_dir=$BASE/$DOC_DIR/cc \ - --src_dir=$TMP_DIR -fi - -popd diff --git a/tensorflow/tools/docs/gen_docs_test.sh b/tensorflow/tools/docs/gen_docs_test.sh deleted file mode 100755 index c8c1955aa06..00000000000 --- a/tensorflow/tools/docs/gen_docs_test.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -set -eux - -if [ -d $TEST_SRCDIR/org_tensorflow ]; then - TFDIR=$TEST_SRCDIR/org_tensorflow/tensorflow -else - # Support 0.2.1- runfiles. 
- TFDIR=$TEST_SRCDIR/tensorflow -fi -DOXYGEN=doxygen -DOXYGEN_CONFIG="tf-doxy_for_md-config" -TMP_DIR=/tmp/tensorflow-docs -mkdir -p $TMP_DIR/python -mkdir -p $TMP_DIR/xml -mkdir -p $TMP_DIR/cc - -pushd $TFDIR -python/gen_docs_combined --out_dir=$TMP_DIR/python - -# TODO(wicke): this does not work well inside the build/test jail -#$DOXYGEN "tools/docs/$DOXYGEN_CONFIG" -#tools/docs/gen_cc_md \ -# --out_dir=$TMP_DIR/cc \ -# --src_dir=$TMP_DIR/xml -popd -echo "PASS" diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py index 8f2958d6a6a..fc93085e3e0 100644 --- a/tensorflow/tools/docs/generate.py +++ b/tensorflow/tools/docs/generate.py @@ -18,216 +18,32 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import inspect import os +import sys -import six import tensorflow as tf -from tensorflow.tools.common import public_api -from tensorflow.tools.common import traverse -from tensorflow.tools.docs import doc_generator_visitor -from tensorflow.tools.docs import parser - - -def write_docs(output_dir, base_dir, duplicate_of, duplicates, index, tree): - """Write previously extracted docs to disk. - - Write a docs page for each symbol in `index` to a tree of docs at - `output_dir`. - - Symbols with multiple aliases will have only one page written about them, - which is referenced for all aliases. `duplicate_of` and `duplicates` are used - to determine which docs pages to write. - - Args: - output_dir: Directory to write documentation markdown files to. Will be - created if it doesn't exist. - base_dir: Base directory of the code being documented. This prefix is - stripped from all file paths that are part of the documentation. - duplicate_of: A `dict` mapping fully qualified names to "master" names. This - is used to resolve "@{symbol}" references to the "master" name. - duplicates: A `dict` mapping fully qualified names to a set of all - aliases of this name. This is used to automatically generate a list of all - aliases for each name. - index: A `dict` mapping fully qualified names to the corresponding Python - objects. Used to produce docs for child objects, and to check the validity - of "@{symbol}" references. - tree: A `dict` mapping a fully qualified name to the names of all its - members. Used to populate the members section of a class or module page. - """ - # Make output_dir. - try: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (output_dir, e)) - raise - - # Parse and write Markdown pages, resolving cross-links (@{symbol}). - for full_name, py_object in six.iteritems(index): - - if full_name in duplicate_of: - print('Not writing docs for %s, duplicate of %s.' % ( - full_name, duplicate_of[full_name])) - continue - - # Methods and some routines are documented only as part of their class. - if not (inspect.ismodule(py_object) or - inspect.isclass(py_object) or - inspect.isfunction(py_object)): - print('Not writing docs for %s, not a class, module, or function.' % ( - full_name)) - continue - - print('Writing docs for %s (%r).' % (full_name, py_object)) - - # Generate docs for `py_object`, resolving references. - markdown = parser.generate_markdown(full_name, py_object, - duplicate_of=duplicate_of, - duplicates=duplicates, - index=index, - tree=tree, - base_dir=base_dir) - - # TODO(deannarubin): use _tree to generate sidebar information. 
- - path = os.path.join(output_dir, parser.documentation_path(full_name)) - directory = os.path.dirname(path) - try: - if not os.path.exists(directory): - os.makedirs(directory) - with open(path, 'w') as f: - f.write(markdown) - except OSError as e: - print('Cannot write documentation for %s to %s: %s' % (full_name, - directory, e)) - raise - # TODO(deannarubin): write sidebar file? - - # Write a global index containing all full names with links. - with open(os.path.join(output_dir, 'full_index.md'), 'w') as f: - f.write(parser.generate_global_index('TensorFlow', 'tensorflow', - index, duplicate_of)) - - -def extract(): - """Extract docs from tf namespace and write them to disk.""" - visitor = doc_generator_visitor.DocGeneratorVisitor() - api_visitor = public_api.PublicAPIVisitor(visitor) - - # Access something in contrib so tf.contrib is properly loaded (it's hidden - # behind lazy loading) - _ = tf.contrib.__name__ - - # Exclude some libaries in contrib from the documentation altogether. - # TODO(wicke): Shrink this list. - api_visitor.do_not_descend_map.update({ - 'contrib': [ - 'compiler', - 'factorization', - 'grid_rnn', - 'labeled_tensor', - 'ndlstm', - 'quantization', - 'session_bundle', - 'slim', - 'solvers', - 'specs', - 'tensor_forest', - 'tensorboard', - 'testing', - 'tfprof', - 'training', - ], - 'contrib.bayesflow': [ - 'entropy', 'monte_carlo', - 'special_math', 'stochastic_gradient_estimators', - 'stochastic_graph', 'stochastic_tensor', - 'stochastic_variables', 'variational_inference' - ], - 'contrib.distributions': ['bijector'], - 'contrib.graph_editor': [ - 'edit', - 'match', - 'reroute', - 'subgraph', - 'transform', - 'select', - 'util' - ], - 'contrib.layers': [ - 'feature_column', - 'summaries' - ], - 'contrib.learn': [ - 'datasets', - 'head', - 'graph_actions', - 'io', - 'models', - 'monitors', - 'ops', - 'preprocessing', - 'utils', - ], - 'contrib.util': ['loader'], - }) - - traverse.traverse(tf, api_visitor) - - return visitor - - -def write(output_dir, base_dir, visitor): - """Write documentation for an index in a `DocGeneratorVisitor` to disk. - - This function will create `output_dir` if it doesn't exist, and write - the documentation contained in `visitor`. - - Args: - output_dir: The directory to write documentation to. Must not exist. - base_dir: The base dir of the library `visitor` has traversed. This is used - to compute relative paths for file references. - visitor: A `DocGeneratorVisitor` that has traversed a library located at - `base_dir`. - """ - duplicate_of, duplicates = visitor.find_duplicates() - write_docs(output_dir, os.path.abspath(base_dir), - duplicate_of, duplicates, visitor.index, visitor.tree) - +from tensorflow.python import debug as tf_debug +from tensorflow.python.util import tf_inspect +from tensorflow.tools.docs import generate_lib if __name__ == '__main__': - argument_parser = argparse.ArgumentParser() - argument_parser.add_argument( - '--output_dir', - type=str, - default=None, - required=True, - help='Directory to write docs to. Must not exist.' - ) + doc_generator = generate_lib.DocGenerator() + doc_generator.add_output_dir_argument() + doc_generator.add_src_dir_argument() # This doc generator works on the TensorFlow codebase. Since this script lives - # at tensorflow/tools/docs, we can compute the base directory (three levels - # up), which is valid unless we're trying to apply this to a different code - # base, or are moving the script around. 
- script_dir = os.path.dirname(inspect.getfile(inspect.currentframe())) - default_base_dir = os.path.join(script_dir, '..', '..', '..') + # at tensorflow/tools/docs, and all code is defined somewhere inside + # tensorflow/, we can compute the base directory (two levels up), which is + # valid unless we're trying to apply this to a different code base, or are + # moving the script around. + script_dir = os.path.dirname(tf_inspect.getfile(tf_inspect.currentframe())) + default_base_dir = os.path.join(script_dir, '..', '..') + doc_generator.add_base_dir_argument(default_base_dir) - argument_parser.add_argument( - '--base_dir', - type=str, - default=default_base_dir, - help=('Base directory to to strip from file names referenced in docs. ' - 'Defaults to three directories up from the location of this file.') - ) + flags = doc_generator.parse_known_args() - flags, _ = argument_parser.parse_known_args() + # tf_debug is not imported with tf, it's a separate module altogether + doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) - if os.path.exists(flags.output_dir): - raise RuntimeError('output_dir %s exists.\n' - 'Cowardly refusing to wipe it, please do that yourself.' - % flags.output_dir) - - write(flags.output_dir, flags.base_dir, extract()) + sys.exit(doc_generator.build(flags)) diff --git a/tensorflow/tools/docs/generate_1_0.py b/tensorflow/tools/docs/generate_1_0.py new file mode 100644 index 00000000000..cdc03fdcacf --- /dev/null +++ b/tensorflow/tools/docs/generate_1_0.py @@ -0,0 +1,93 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generate docs for the TensorFlow Python API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.python import debug as tf_debug +from tensorflow.python.util import tf_inspect +from tensorflow.tools.docs import generate_lib + +if __name__ == '__main__': + doc_generator = generate_lib.DocGenerator() + doc_generator.add_output_dir_argument() + doc_generator.add_src_dir_argument() + + # This doc generator works on the TensorFlow codebase. Since this script lives + # at tensorflow/tools/docs, and all code is defined somewhere inside + # tensorflow/, we can compute the base directory (two levels up), which is + # valid unless we're trying to apply this to a different code base, or are + # moving the script around. 
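+ # For instance (paths illustrative): with this file at
+ # <repo>/tensorflow/tools/docs/generate_1_0.py, script_dir below is
+ # <repo>/tensorflow/tools/docs, so default_base_dir resolves to
+ # <repo>/tensorflow, two levels up.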
+ script_dir = os.path.dirname(tf_inspect.getfile(tf_inspect.currentframe())) + default_base_dir = os.path.join(script_dir, '..', '..') + doc_generator.add_base_dir_argument(default_base_dir) + + flags = doc_generator.parse_known_args() + + # tf_debug is not imported with tf, it's a separate module altogether + doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) + + doc_generator.set_do_not_descend_map({ + 'tf': ['cli', 'lib', 'wrappers'], + 'tf.contrib': [ + 'compiler', + 'factorization', + 'grid_rnn', + 'labeled_tensor', + 'ndlstm', + 'quantization', + 'session_bundle', + 'slim', + 'solvers', + 'specs', + 'tensor_forest', + 'tensorboard', + 'testing', + 'training', + 'tfprof', + ], + 'tf.contrib.bayesflow': [ + 'entropy', 'monte_carlo', 'special_math', + 'stochastic_gradient_estimators', 'stochastic_graph', + 'stochastic_tensor', 'stochastic_variables', 'variational_inference' + ], + 'tf.contrib.distributions': ['bijector'], + 'tf.contrib.ffmpeg': ['ffmpeg_ops'], + 'tf.contrib.graph_editor': [ + 'edit', 'match', 'reroute', 'subgraph', 'transform', 'select', 'util' + ], + 'tf.contrib.layers': ['feature_column', 'summaries'], + 'tf.contrib.learn': [ + 'datasets', + 'head', + 'graph_actions', + 'io', + 'models', + 'monitors', + 'ops', + 'preprocessing', + 'utils', + ], + 'tf.contrib.util': ['loader'], + }) + + sys.exit(doc_generator.build(flags)) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py new file mode 100644 index 00000000000..99872e1d844 --- /dev/null +++ b/tensorflow/tools/docs/generate_lib.py @@ -0,0 +1,511 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generate docs for the TensorFlow Python API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os + +import six + +from tensorflow.python.util import tf_inspect +from tensorflow.tools.common import public_api +from tensorflow.tools.common import traverse +from tensorflow.tools.docs import doc_generator_visitor +from tensorflow.tools.docs import parser +from tensorflow.tools.docs import pretty_docs +from tensorflow.tools.docs import py_guide_parser + + +def _is_free_function(py_object, full_name, index): + """Check if input is a free function (and not a class- or static method).""" + if not tf_inspect.isfunction(py_object): + return False + + # Static methods are functions to tf_inspect (in 2.7), so check if the parent + # is a class. If there is no parent, it's not a function. + if '.' not in full_name: + return False + + parent_name = full_name.rsplit('.', 1)[0] + if tf_inspect.isclass(index[parent_name]): + return False + + return True + + +def write_docs(output_dir, parser_config, yaml_toc): + """Write previously extracted docs to disk. 
+ + Write a docs page for each symbol included in the indices of parser_config to + a tree of docs at `output_dir`. + + Symbols with multiple aliases will have only one page written about + them, which is referenced for all aliases. + + Args: + output_dir: Directory to write documentation markdown files to. Will be + created if it doesn't exist. + parser_config: A `parser.ParserConfig` object, containing all the necessary + indices. + yaml_toc: Set to `True` to generate a "_toc.yaml" file. + + Raises: + ValueError: if `output_dir` is not an absolute path + """ + # Make output_dir. + if not os.path.isabs(output_dir): + raise ValueError( + "'output_dir' must be an absolute path.\n" + " output_dir='%s'" % output_dir) + + try: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + except OSError as e: + print('Creating output dir "%s" failed: %s' % (output_dir, e)) + raise + + # These dictionaries are used for table-of-contents generation below + # They will contain, after the for-loop below:: + # - module name(string):classes and functions the module contains(list) + module_children = {} + # - symbol name(string):pathname (string) + symbol_to_file = {} + + # Parse and write Markdown pages, resolving cross-links (@{symbol}). + for full_name, py_object in six.iteritems(parser_config.index): + + if full_name in parser_config.duplicate_of: + continue + + # Methods and some routines are documented only as part of their class. + if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or + _is_free_function(py_object, full_name, parser_config.index)): + continue + + sitepath = os.path.join('api_docs/python', + parser.documentation_path(full_name)[:-3]) + + # For TOC, we need to store a mapping from full_name to the file + # we're generating + symbol_to_file[full_name] = sitepath + + # For a module, remember the module for the table-of-contents + if tf_inspect.ismodule(py_object): + if full_name in parser_config.tree: + module_children.setdefault(full_name, []) + + # For something else that's documented, + # figure out what module it lives in + else: + subname = str(full_name) + while True: + subname = subname[:subname.rindex('.')] + if tf_inspect.ismodule(parser_config.index[subname]): + module_children.setdefault(subname, []).append(full_name) + break + + print('Writing docs for %s (%r).' % (full_name, py_object)) + + # Generate docs for `py_object`, resolving references. 
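+ # For example, full_name 'tf.TestModule.TestClass' would yield a page_info
+ # whose markdown (rendered below by pretty_docs.build_md_page) lands at
+ # tf/TestModule/TestClass.md under output_dir; the name is illustrative,
+ # borrowed from the accompanying test.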
+ page_info = parser.docs_for_object(full_name, py_object, parser_config) + + path = os.path.join(output_dir, parser.documentation_path(full_name)) + directory = os.path.dirname(path) + try: + if not os.path.exists(directory): + os.makedirs(directory) + with open(path, 'w') as f: + f.write(pretty_docs.build_md_page(page_info)) + except OSError as e: + print('Cannot write documentation for %s to %s: %s' % (full_name, + directory, e)) + raise + + if yaml_toc: + # Generate table of contents + + # Put modules in alphabetical order, case-insensitive + modules = sorted(module_children.keys(), key=lambda a: a.upper()) + + leftnav_path = os.path.join(output_dir, '_toc.yaml') + with open(leftnav_path, 'w') as f: + + # Generate header + f.write('# Automatically generated file; please do not edit\ntoc:\n') + for module in modules: + f.write(' - title: ' + module + '\n' + ' section:\n' + ' - title: Overview\n' + + ' path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[module] + + '\n') + + symbols_in_module = module_children.get(module, []) + # Sort case-insensitive, if equal sort case sensitive (upper first) + symbols_in_module.sort(key=lambda a: (a.upper(), a)) + + for full_name in symbols_in_module: + f.write(' - title: ' + full_name[len(module) + 1:] + '\n' + ' path: /TARGET_DOC_ROOT/VERSION/' + + symbol_to_file[full_name] + '\n') + + # Write a global index containing all full names with links. + with open(os.path.join(output_dir, 'index.md'), 'w') as f: + f.write( + parser.generate_global_index('TensorFlow', parser_config.index, + parser_config.reference_resolver)) + + +def add_dict_to_dict(add_from, add_to): + for key in add_from: + if key in add_to: + add_to[key].extend(add_from[key]) + else: + add_to[key] = add_from[key] + + +# Exclude some libaries in contrib from the documentation altogether. +def _get_default_private_map(): + return {} + + +# Exclude members of some libaries. +def _get_default_do_not_descend_map(): + # TODO(wicke): Shrink this list once the modules get sealed. + return { + 'tf': ['cli', 'lib', 'wrappers'], + 'tf.contrib': [ + 'compiler', + 'factorization', + 'grid_rnn', + 'labeled_tensor', + 'ndlstm', + 'quantization', + 'session_bundle', + 'slim', + 'solvers', + 'specs', + 'tensor_forest', + 'tensorboard', + 'testing', + 'tfprof', + ], + 'tf.contrib.bayesflow': [ + 'special_math', 'stochastic_gradient_estimators', + 'stochastic_variables' + ], + 'tf.contrib.ffmpeg': ['ffmpeg_ops'], + 'tf.contrib.graph_editor': [ + 'edit', 'match', 'reroute', 'subgraph', 'transform', 'select', 'util' + ], + 'tf.contrib.keras': ['api', 'python'], + 'tf.contrib.layers': ['feature_column', 'summaries'], + 'tf.contrib.learn': [ + 'datasets', + 'head', + 'graph_actions', + 'io', + 'models', + 'monitors', + 'ops', + 'preprocessing', + 'utils', + ], + 'tf.contrib.util': ['loader'], + } + + +def extract(py_modules, private_map, do_not_descend_map): + """Extract docs from tf namespace and write them to disk.""" + # Traverse the first module. 
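+ # E.g., with py_modules == [('tf', tf), ('tfdbg', tf_debug)], the visitor
+ # is rooted at 'tf' first; the loop below re-roots it at 'tfdbg' so that
+ # module's symbols are indexed under 'tfdbg.*'.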
+ visitor = doc_generator_visitor.DocGeneratorVisitor(py_modules[0][0]) + api_visitor = public_api.PublicAPIVisitor(visitor) + api_visitor.set_root_name(py_modules[0][0]) + add_dict_to_dict(private_map, api_visitor.private_map) + add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map) + + traverse.traverse(py_modules[0][1], api_visitor) + + # Traverse all py_modules after the first: + for module_name, module in py_modules[1:]: + visitor.set_root_name(module_name) + api_visitor.set_root_name(module_name) + traverse.traverse(module, api_visitor) + + return visitor + + +class _GetMarkdownTitle(py_guide_parser.PyGuideParser): + """Extract the title from a .md file.""" + + def __init__(self): + self.title = None + py_guide_parser.PyGuideParser.__init__(self) + + def process_title(self, _, title): + if self.title is None: # only use the first title + self.title = title + + +class _DocInfo(object): + """A simple struct for holding a doc's url and title.""" + + def __init__(self, url, title): + self.url = url + self.title = title + + +def build_doc_index(src_dir): + """Build an index from a keyword designating a doc to _DocInfo objects.""" + doc_index = {} + if not os.path.isabs(src_dir): + raise ValueError("'src_dir' must be an absolute path.\n" + " src_dir='%s'" % src_dir) + + if not os.path.exists(src_dir): + raise ValueError("'src_dir' path must exist.\n" + " src_dir='%s'" % src_dir) + + for dirpath, _, filenames in os.walk(src_dir): + suffix = os.path.relpath(path=dirpath, start=src_dir) + for base_name in filenames: + if not base_name.endswith('.md'): + continue + title_parser = _GetMarkdownTitle() + title_parser.process(os.path.join(dirpath, base_name)) + key_parts = os.path.join(suffix, base_name[:-3]).split('/') + if key_parts[-1] == 'index': + key_parts = key_parts[:-1] + doc_info = _DocInfo(os.path.join(suffix, base_name), title_parser.title) + doc_index[key_parts[-1]] = doc_info + if len(key_parts) > 1: + doc_index['/'.join(key_parts[-2:])] = doc_info + + return doc_index + + +class _GuideRef(object): + + def __init__(self, base_name, title, section_title, section_tag): + self.url = 'api_guides/python/' + (('%s#%s' % (base_name, section_tag)) + if section_tag else base_name) + self.link_text = (('%s > %s' % (title, section_title)) + if section_title else title) + + def make_md_link(self, url_prefix): + return '[%s](%s%s)' % (self.link_text, url_prefix, self.url) + + +class _GenerateGuideIndex(py_guide_parser.PyGuideParser): + """Turn guide files into an index from symbol name to a list of _GuideRefs.""" + + def __init__(self): + self.index = {} + py_guide_parser.PyGuideParser.__init__(self) + + def process(self, full_path, base_name): + """Index a file, reading from `full_path`, with `base_name` as the link.""" + self.full_path = full_path + self.base_name = base_name + self.title = None + self.section_title = None + self.section_tag = None + py_guide_parser.PyGuideParser.process(self, full_path) + + def process_title(self, _, title): + if self.title is None: # only use the first title + self.title = title + + def process_section(self, _, section_title, tag): + self.section_title = section_title + self.section_tag = tag + + def process_line(self, _, line): + """Index @{symbol} references as in the current file & section.""" + for match in parser.SYMBOL_REFERENCE_RE.finditer(line): + val = self.index.get(match.group(1), []) + val.append( + _GuideRef(self.base_name, self.title, self.section_title, + self.section_tag)) + self.index[match.group(1)] = val + + +def 
_build_guide_index(guide_src_dir): + """Return dict: symbol name -> _GuideRef from the files in `guide_src_dir`.""" + index_generator = _GenerateGuideIndex() + if os.path.exists(guide_src_dir): + for full_path, base_name in py_guide_parser.md_files_in_dir(guide_src_dir): + index_generator.process(full_path, base_name) + return index_generator.index + + +class _UpdateTags(py_guide_parser.PyGuideParser): + """Rewrites a Python guide so that each section has an explicit tag.""" + + def process_section(self, line_number, section_title, tag): + self.replace_line(line_number, '
<h2 id="%s">%s</h2>
' % (tag, section_title)) + + +EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) + + +def _other_docs(src_dir, output_dir, reference_resolver): + """Convert all the files in `src_dir` and write results to `output_dir`.""" + header = '\n' + + # Iterate through all the source files and process them. + tag_updater = _UpdateTags() + for dirpath, _, filenames in os.walk(src_dir): + # How to get from `dirpath` to api_docs/python/ + relative_path_to_root = os.path.relpath( + path=os.path.join(src_dir, 'api_docs/python'), start=dirpath) + + # Make the directory under output_dir. + new_dir = os.path.join(output_dir, + os.path.relpath(path=dirpath, start=src_dir)) + try: + if not os.path.exists(new_dir): + os.makedirs(new_dir) + except OSError as e: + print('Creating output dir "%s" failed: %s' % (new_dir, e)) + raise + + for base_name in filenames: + if base_name in EXCLUDED: + print('Skipping excluded file %s...' % base_name) + continue + full_in_path = os.path.join(dirpath, base_name) + suffix = os.path.relpath(path=full_in_path, start=src_dir) + full_out_path = os.path.join(output_dir, suffix) + if not base_name.endswith('.md'): + print('Copying non-md file %s...' % suffix) + open(full_out_path, 'w').write(open(full_in_path).read()) + continue + if dirpath.endswith('/api_guides/python'): + print('Processing Python guide %s...' % base_name) + md_string = tag_updater.process(full_in_path) + else: + print('Processing doc %s...' % suffix) + md_string = open(full_in_path).read() + + output = reference_resolver.replace_references(md_string, + relative_path_to_root) + with open(full_out_path, 'w') as f: + f.write(header + output) + + print('Done.') + + +class DocGenerator(object): + """Main entry point for generating docs.""" + + def __init__(self): + self.argument_parser = argparse.ArgumentParser() + self._py_modules = None + self._private_map = _get_default_private_map() + self._do_not_descend_map = _get_default_do_not_descend_map() + self.yaml_toc = True + + def add_output_dir_argument(self): + self.argument_parser.add_argument( + '--output_dir', + type=str, + default=None, + required=True, + help='Directory to write docs to.') + + def add_src_dir_argument(self): + self.argument_parser.add_argument( + '--src_dir', + type=str, + default=None, + required=True, + help='Directory with the source docs.') + + def add_base_dir_argument(self, default_base_dir): + self.argument_parser.add_argument( + '--base_dir', + type=str, + default=default_base_dir, + help='Base directory to to strip from file names referenced in docs.') + + def parse_known_args(self): + flags, _ = self.argument_parser.parse_known_args() + return flags + + def add_to_private_map(self, d): + add_dict_to_dict(d, self._private_map) + + def add_to_do_not_descend_map(self, d): + add_dict_to_dict(d, self._do_not_descend_map) + + def set_private_map(self, d): + self._private_map = d + + def set_do_not_descend_map(self, d): + self._do_not_descend_map = d + + def set_py_modules(self, py_modules): + self._py_modules = py_modules + + def py_module_names(self): + if self._py_modules is None: + raise RuntimeError( + 'Must call set_py_modules() before running py_module_names().') + return [name for (name, _) in self._py_modules] + + def make_reference_resolver(self, visitor, doc_index): + return parser.ReferenceResolver.from_visitor( + visitor, doc_index, py_module_names=self.py_module_names()) + + def make_parser_config(self, visitor, reference_resolver, guide_index, + base_dir): + return parser.ParserConfig( + 
reference_resolver=reference_resolver, + duplicates=visitor.duplicates, + duplicate_of=visitor.duplicate_of, + tree=visitor.tree, + index=visitor.index, + reverse_index=visitor.reverse_index, + guide_index=guide_index, + base_dir=base_dir) + + def run_extraction(self): + return extract( + self._py_modules, self._private_map, self._do_not_descend_map) + + def build(self, flags): + """Actually build the docs.""" + doc_index = build_doc_index(flags.src_dir) + visitor = self.run_extraction() + reference_resolver = self.make_reference_resolver(visitor, doc_index) + + guide_index = _build_guide_index( + os.path.join(flags.src_dir, 'api_guides/python')) + + parser_config = self.make_parser_config(visitor, reference_resolver, + guide_index, flags.base_dir) + output_dir = os.path.join(flags.output_dir, 'api_docs/python') + + write_docs(output_dir, parser_config, yaml_toc=self.yaml_toc) + _other_docs(flags.src_dir, flags.output_dir, reference_resolver) + + if parser.all_errors: + print('Errors during processing:\n ' + '\n '.join(parser.all_errors)) + return 1 + return 0 diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py new file mode 100644 index 00000000000..6e5deb6a36e --- /dev/null +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -0,0 +1,151 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for doc generator traversal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +import tensorflow as tf + +from tensorflow.python import debug as tf_debug +from tensorflow.python.platform import googletest +from tensorflow.tools.docs import generate_lib +from tensorflow.tools.docs import parser + + +def test_function(): + """Docstring for test_function.""" + pass + + +class TestClass(object): + """Docstring for TestClass itself.""" + + class ChildClass(object): + """Docstring for a child class.""" + + class GrandChildClass(object): + """Docstring for a child of a child class.""" + pass + + +class DummyVisitor(object): + + def __init__(self, index, duplicate_of): + self.index = index + self.duplicate_of = duplicate_of + + +class GenerateTest(googletest.TestCase): + + def test_extraction(self): + py_modules = [('tf', tf), ('tfdbg', tf_debug)] + + try: + generate_lib.extract(py_modules, + generate_lib._get_default_private_map(), + generate_lib._get_default_do_not_descend_map()) + except RuntimeError: + print('*****************************************************************') + print('If this test fails, you have most likely introduced an unsealed') + print('module. Make sure to use remove_undocumented or similar utilities') + print('to avoid leaking symbols. 
See below for more information on the') + print('failure.') + print('*****************************************************************') + raise + + def test_write(self): + module = sys.modules[__name__] + + index = { + 'tf': sys, # Can be any module, this test doesn't care about content. + 'tf.TestModule': module, + 'tf.test_function': test_function, + 'tf.TestModule.test_function': test_function, + 'tf.TestModule.TestClass': TestClass, + 'tf.TestModule.TestClass.ChildClass': TestClass.ChildClass, + 'tf.TestModule.TestClass.ChildClass.GrandChildClass': + TestClass.ChildClass.GrandChildClass, + } + + tree = { + 'tf': ['TestModule', 'test_function'], + 'tf.TestModule': ['test_function', 'TestClass'], + 'tf.TestModule.TestClass': ['ChildClass'], + 'tf.TestModule.TestClass.ChildClass': ['GrandChildClass'], + 'tf.TestModule.TestClass.ChildClass.GrandChildClass': [] + } + + duplicate_of = {'tf.test_function': 'tf.TestModule.test_function'} + + duplicates = { + 'tf.TestModule.test_function': [ + 'tf.test_function', 'tf.TestModule.test_function' + ] + } + + base_dir = os.path.dirname(__file__) + + visitor = DummyVisitor(index, duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates=duplicates, + duplicate_of=duplicate_of, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir=base_dir) + + output_dir = googletest.GetTempDir() + + generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) + + # Make sure that the right files are written to disk. + self.assertTrue(os.path.exists(os.path.join(output_dir, 'index.md'))) + self.assertTrue(os.path.exists(os.path.join(output_dir, 'tf.md'))) + self.assertTrue(os.path.exists(os.path.join(output_dir, '_toc.yaml'))) + self.assertTrue( + os.path.exists(os.path.join(output_dir, 'tf/TestModule.md'))) + self.assertFalse( + os.path.exists(os.path.join(output_dir, 'tf/test_function.md'))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, 'tf/TestModule/TestClass.md'))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, + 'tf/TestModule/TestClass/ChildClass.md'))) + self.assertTrue( + os.path.exists( + os.path.join( + output_dir, + 'tf/TestModule/TestClass/ChildClass/GrandChildClass.md'))) + # Make sure that duplicates are not written + self.assertTrue( + os.path.exists( + os.path.join(output_dir, 'tf/TestModule/test_function.md'))) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/docs/generate_test.py b/tensorflow/tools/docs/generate_test.py deleted file mode 100644 index 4594676109c..00000000000 --- a/tensorflow/tools/docs/generate_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for doc generator traversal.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import sys -import tempfile - -from tensorflow.python.platform import googletest -from tensorflow.tools.docs import generate - - -def test_function(): - """Docstring for test_function.""" - pass - - -class TestClass(object): - """Docstring for TestClass itself.""" - - class ChildClass(object): - """Docstring for a child class.""" - - class GrandChildClass(object): - """Docstring for a child of a child class.""" - pass - - -class GenerateTest(googletest.TestCase): - - def test_extraction(self): - try: - generate.extract() - except RuntimeError: - print('*****************************************************************') - print('If this test fails, you have most likely introduced an unsealed') - print('module. Make sure to use remove_undocumented or similar utilities') - print('to avoid leaking symbols. See below for more information on the') - print('failure.') - print('*****************************************************************') - raise - - def test_write(self): - module = sys.modules[__name__] - - index = { - '': sys, # Can be any module, this test doesn't care about content. - 'TestModule': module, - 'test_function': test_function, - 'TestModule.test_function': test_function, - 'TestModule.TestClass': TestClass, - 'TestModule.TestClass.ChildClass': TestClass.ChildClass, - 'TestModule.TestClass.ChildClass.GrandChildClass': - TestClass.ChildClass.GrandChildClass, - } - - tree = { - '': ['TestModule', 'test_function'], - 'TestModule': ['test_function', 'TestClass'], - 'TestModule.TestClass': ['ChildClass'], - 'TestModule.TestClass.ChildClass': ['GrandChildClass'], - 'TestModule.TestClass.ChildClass.GrandChildClass': [] - } - - duplicate_of = { - 'TestModule.test_function': 'test_function' - } - - duplicates = { - 'test_function': ['test_function', 'TestModule.test_function'] - } - - output_dir = tempfile.mkdtemp() - base_dir = os.path.dirname(__file__) - - generate.write_docs(output_dir, base_dir, - duplicate_of, duplicates, - index, tree) - - # Make sure that the right files are written to disk. - self.assertTrue(os.path.exists(os.path.join(output_dir, 'index.md'))) - self.assertTrue(os.path.exists(os.path.join(output_dir, 'full_index.md'))) - self.assertTrue(os.path.exists(os.path.join(output_dir, 'TestModule.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'test_function.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'TestModule/TestClass.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'TestModule/TestClass/ChildClass.md'))) - self.assertTrue(os.path.exists(os.path.join( - output_dir, 'TestModule/TestClass/ChildClass/GrandChildClass.md'))) - # Make sure that duplicates are not written - self.assertFalse(os.path.exists(os.path.join( - output_dir, 'TestModule/test_function.md'))) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/tools/docs/make_py_guides.py b/tensorflow/tools/docs/make_py_guides.py deleted file mode 100644 index a5264f2f8dc..00000000000 --- a/tensorflow/tools/docs/make_py_guides.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Convert @{symbol} to MarkDown links in the Python API guides.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os - -from tensorflow.tools.docs import generate -from tensorflow.tools.docs import parser - - -def _md_files_in_dir(input_dir): - all_in_dir = [(os.path.join(input_dir, f), f) for f in os.listdir(input_dir)] - return [(full, f) for full, f in all_in_dir - if os.path.isfile(full) and f.endswith('.md')] - - -def _main(input_dir, output_dir): - """Convert all the files in `input_dir` and write results to `output_dir`.""" - visitor = generate.extract() - duplicate_of, unused_duplicates = visitor.find_duplicates() - - # Make output_dir. - try: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (output_dir, e)) - raise - - # How to get from api_guides/python/ to api_docs/python/ - relative_path_to_root = '../../api_docs/python/' - - # Iterate through all the source files and process them. - for full_path, base_name in _md_files_in_dir(input_dir): - print('Processing %s...' % base_name) - md_string = open(full_path).read() - output = parser.replace_references( - md_string, relative_path_to_root, duplicate_of) - open(os.path.join(output_dir, base_name), 'w').write(output) - print('Done.') - - -if __name__ == '__main__': - argument_parser = argparse.ArgumentParser() - argument_parser.add_argument( - '--input_dir', - type=str, - default=None, - required=True, - help='Directory to copy docs from.' - ) - argument_parser.add_argument( - '--output_dir', - type=str, - default=None, - required=True, - help='Directory to write docs to. Will be created, must not exist.' - ) - flags, _ = argument_parser.parse_known_args() - if os.path.exists(flags.output_dir): - raise RuntimeError('output_dir %s exists.\n' - 'Cowardly refusing to wipe it, please do that yourself.' - % flags.output_dir) - - _main(flags.input_dir, flags.output_dir) diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 3a1a7cb82e9..7ae1d2abd9a 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -18,16 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import ast +import collections import functools -import inspect +import json import os import re +import codegen import six +from google.protobuf.message import Message as ProtoMessage +from tensorflow.python.util import tf_inspect + + # A regular expression capturing a python indentifier. IDENTIFIER_RE = '[a-zA-Z_][a-zA-Z0-9_]*' +# Log of all reported errors +all_errors = [] + + +def log_error(s): + all_errors.append(s) + print('ERROR:', s) + def documentation_path(full_name): """Returns the file path for the documentation for the given API symbol. 
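The path mapping changed by the hunk below is simple enough to pin down with a standalone sketch (mirroring the `documentation_path` body in this diff; the sample symbol is illustrative):

```python
import os

def documentation_path(full_name):
  # Dots in the fully qualified name become directory separators,
  # and '.md' is appended: 'tf.nn.relu' -> 'tf/nn/relu.md'.
  dirs = full_name.split('.')
  return os.path.join(*dirs) + '.md'

assert documentation_path('tf.nn.relu') == os.path.join('tf', 'nn', 'relu.md')
```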
@@ -35,8 +50,7 @@ def documentation_path(full_name): Given the fully qualified name of a library symbol, compute the path to which to write the documentation for that symbol (relative to a base directory). Documentation files are organized into directories that mirror the python - module/class structure. The path for the top-level module (whose full name is - '') is 'index.md'. + module/class structure. Args: full_name: Fully qualified name of a library symbol. @@ -44,12 +58,7 @@ def documentation_path(full_name): Returns: The file path to which to write the documentation for `full_name`. """ - # The main page is special, since it has no name in here. - if not full_name: - dirs = ['index'] - else: - dirs = full_name.split('.') - + dirs = full_name.split('.') return os.path.join(*dirs) + '.md' @@ -63,181 +72,466 @@ def _get_raw_docstring(py_object): Returns: The docstring, or the empty string if no docstring was found. """ - # For object instances, inspect.getdoc does give us the docstring of their + # For object instances, tf_inspect.getdoc does give us the docstring of their # type, which is not what we want. Only return the docstring if it is useful. - if (inspect.isclass(py_object) or inspect.ismethod(py_object) or - inspect.isfunction(py_object) or inspect.ismodule(py_object) or + if (tf_inspect.isclass(py_object) or tf_inspect.ismethod(py_object) or + tf_inspect.isfunction(py_object) or tf_inspect.ismodule(py_object) or isinstance(py_object, property)): - return inspect.getdoc(py_object) or '' + return tf_inspect.getdoc(py_object) or '' else: return '' -def _get_brief_docstring(py_object): - """Gets the one line docstring of a python object.""" - return _get_raw_docstring(py_object).split('\n')[0] +# A regular expression for capturing a @{symbol} reference. +SYMBOL_REFERENCE_RE = re.compile(r'@\{([^}]+)\}') -def _reference_to_link(ref_full_name, relative_path_to_root, duplicate_of): - """Resolve a "@{symbol}" reference to a relative path, respecting duplicates. - - The input to this function should already be stripped of the '@' and '{}', and - its output is only the link, not the full Markdown. +class ReferenceResolver(object): + """Class for replacing @{...} references with Markdown links. Args: - ref_full_name: The fully qualified name of the symbol to link to. - relative_path_to_root: The relative path from the location of the current - document to the root of the API documentation. - duplicate_of: A map from duplicate full names to master names. - - Returns: - A relative path that links from the documentation page of `from_full_name` - to the documentation page of `ref_full_name`. + duplicate_of: A map from duplicate names to preferred names of API + symbols. + doc_index: A `dict` mapping symbol name strings to objects with `url` + and `title` fields. Used to resolve @{$doc} references in docstrings. + index: A map from all full names to python objects. + py_module_names: A list of string names of Python modules. 
""" - master_name = duplicate_of.get(ref_full_name, ref_full_name) - ref_path = documentation_path(master_name) - return os.path.join(relative_path_to_root, ref_path) + + def __init__(self, duplicate_of, doc_index, is_class, is_module, + py_module_names): + self._duplicate_of = duplicate_of + self._doc_index = doc_index + self._is_class = is_class + self._is_module = is_module + self._all_names = set(is_class.keys()) + self._py_module_names = py_module_names + + @classmethod + def from_visitor(cls, visitor, doc_index, **kwargs): + """A factory function for building a ReferenceResolver from a visitor. + + Args: + visitor: an instance of `DocGeneratorVisitor` + doc_index: a dictionary mapping document names to references objects with + "title" and "url" fields + **kwargs: all remaining args are passed to the constructor + Returns: + an instance of `ReferenceResolver` () + """ + is_class = { + name: tf_inspect.isclass(visitor.index[name]) + for name, obj in visitor.index.items() + } + + is_module = { + name: tf_inspect.ismodule(visitor.index[name]) + for name, obj in visitor.index.items() + } + + return cls( + duplicate_of=visitor.duplicate_of, + doc_index=doc_index, + is_class=is_class, + is_module=is_module, + **kwargs) + + @classmethod + def from_json_file(cls, filepath, doc_index): + with open(filepath) as f: + json_dict = json.load(f) + + return cls(doc_index=doc_index, **json_dict) + + def to_json_file(self, filepath): + """Converts the RefenceResolver to json and writes it to the specified file. + + Args: + filepath: The file path to write the json to. + """ + json_dict = {} + for key, value in self.__dict__.items(): + # Drop these two fields. `_doc_index` is not serializable. `_all_names` is + # generated by the constructor. + if key in ('_doc_index', '_all_names'): + continue + + # Strip off any leading underscores on field names as these are not + # recognized by the constructor. + json_dict[key.lstrip('_')] = value + + with open(filepath, 'w') as f: + json.dump(json_dict, f) + + def replace_references(self, string, relative_path_to_root): + """Replace "@{symbol}" references with links to symbol's documentation page. + + This functions finds all occurrences of "@{symbol}" in `string` + and replaces them with markdown links to the documentation page + for "symbol". + + `relative_path_to_root` is the relative path from the document + that contains the "@{symbol}" reference to the root of the API + documentation that is linked to. If the containing page is part of + the same API docset, `relative_path_to_root` can be set to + `os.path.dirname(documentation_path(name))`, where `name` is the + python name of the object whose documentation page the reference + lives on. + + Args: + string: A string in which "@{symbol}" references should be replaced. + relative_path_to_root: The relative path from the containing document to + the root of the API documentation that is being linked to. + + Returns: + `string`, with "@{symbol}" references replaced by Markdown links. + """ + return re.sub(SYMBOL_REFERENCE_RE, + lambda match: self._one_ref(match.group(1), # pylint: disable=g-long-lambda + relative_path_to_root), + string) + + def python_link(self, link_text, ref_full_name, relative_path_to_root, + code_ref=True): + """Resolve a "@{python symbol}" reference to a Markdown link. + + This will pick the canonical location for duplicate symbols. The + input to this function should already be stripped of the '@' and + '{}'. This function returns a Markdown link. 
If `code_ref` is + true, it is assumed that this is a code reference, so the link + text will be rendered as code (using backticks). + `link_text` should refer to a library symbol, starting with 'tf.'. + + Args: + link_text: The text of the Markdown link. + ref_full_name: The fully qualified name of the symbol to link to. + relative_path_to_root: The relative path from the location of the current + document to the root of the API documentation. + code_ref: If true (the default), put `link_text` in `...`. + + Returns: + A markdown link to the documentation page of `ref_full_name`. + """ + link = self.reference_to_url(ref_full_name, relative_path_to_root) + if code_ref: + return '[`%s`](%s)' % (link_text, link) + else: + return '[%s](%s)' % (link_text, link) + + def py_master_name(self, full_name): + """Return the master name for a Python symbol name.""" + return self._duplicate_of.get(full_name, full_name) + + def reference_to_url(self, ref_full_name, relative_path_to_root): + """Resolve a "@{python symbol}" reference to a relative path. + + The input to this function should already be stripped of the '@' + and '{}', and its output is only the link, not the full Markdown. + + If `ref_full_name` is the name of a class member, method, or property, the + link will point to the page of the containing class, and it will include the + method name as an anchor. For example, `tf.module.MyClass.my_method` will be + translated into a link to + `os.join.path(relative_path_to_root, 'tf/module/MyClass.md#my_method')`. + + Args: + ref_full_name: The fully qualified name of the symbol to link to. + relative_path_to_root: The relative path from the location of the current + document to the root of the API documentation. + + Returns: + A relative path that links from the documentation page of `from_full_name` + to the documentation page of `ref_full_name`. + + Raises: + RuntimeError: If `ref_full_name` is not documented. + """ + master_name = self._duplicate_of.get(ref_full_name, ref_full_name) + + # Check whether this link exists + if master_name not in self._all_names: + # TODO(josh11b): Make error reporting more uniform. + print('ERROR: Cannot make link to %s (original: %s): Not in index.' % + (master_name, ref_full_name)) + return 'BROKEN_LINK' + + # If this is a member of a class, link to the class page with an anchor. + ref_path = None + if not (self._is_class[master_name] or self._is_module[master_name]): + idents = master_name.split('.') + if len(idents) > 1: + class_name = '.'.join(idents[:-1]) + assert class_name in self._all_names + if self._is_class[class_name]: + ref_path = documentation_path(class_name) + '#%s' % idents[-1] + + if not ref_path: + ref_path = documentation_path(master_name) + + return os.path.join(relative_path_to_root, ref_path) + + def _one_ref(self, string, relative_path_to_root): + """Return a link for a single "@{symbol}" reference.""" + # Look for link text after $. + dollar = string.rfind('$') + if dollar > 0: # Ignore $ in first character + link_text = string[dollar + 1:] + string = string[:dollar] + manual_link_text = True + else: + link_text = string + manual_link_text = False + + # Handle different types of references. 
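+ # Illustrative inputs: '@{$python/math_ops}' is a doc reference,
+ # '@{tensorflow::Tensor}' a C++ symbol, '@{tf.add}' a Python symbol, and
+ # '@{tf.add$the add op}' carries manual link text after the trailing '$'
+ # (the exact symbol names here are hypothetical).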
+
+  def _one_ref(self, string, relative_path_to_root):
+    """Return a link for a single "@{symbol}" reference."""
+    # Look for link text after $.
+    dollar = string.rfind('$')
+    if dollar > 0:  # Ignore a $ in the first character
+      link_text = string[dollar + 1:]
+      string = string[:dollar]
+      manual_link_text = True
+    else:
+      link_text = string
+      manual_link_text = False
+
+    # Handle different types of references.
+    if string.startswith('$'):  # Doc reference
+      return self._doc_link(
+          string, link_text, manual_link_text, relative_path_to_root)
+
+    elif string.startswith('tensorflow::'):
+      # C++ symbol
+      return self._cc_link(
+          string, link_text, manual_link_text, relative_path_to_root)
+
+    else:
+      is_python = False
+      for py_module_name in self._py_module_names:
+        if string == py_module_name or string.startswith(py_module_name + '.'):
+          is_python = True
+          break
+      if is_python:  # Python symbol
+        return self.python_link(link_text, string, relative_path_to_root,
+                                code_ref=not manual_link_text)
+
+    # Error!
+    log_error('Did not understand "@{%s}"' % string)
+    return 'ERROR:%s' % string
+
+  def _doc_link(self, string, link_text, manual_link_text,
+                relative_path_to_root):
+    """Generate a link for a @{$...} reference."""
+    string = string[1:]  # remove leading $
+
+    # If string has a #, split that part into `hash_tag`
+    hash_pos = string.find('#')
+    if hash_pos > -1:
+      hash_tag = string[hash_pos:]
+      string = string[:hash_pos]
+    else:
+      hash_tag = ''
+
+    if string in self._doc_index:
+      if not manual_link_text:
+        link_text = self._doc_index[string].title
+      url = os.path.normpath(os.path.join(
+          relative_path_to_root, '../..', self._doc_index[string].url))
+      return '[%s](%s%s)' % (link_text, url, hash_tag)
+    return self._doc_missing(string, hash_tag, link_text, manual_link_text,
+                             relative_path_to_root)
+
+  def _doc_missing(self, string, unused_hash_tag, link_text,
+                   unused_manual_link_text, unused_relative_path_to_root):
+    """Generate an error for unrecognized @{$...} references."""
+    log_error('Handle doc reference "@{$%s}"' % string)
+    return link_text
+
+  def _cc_link(self, string, link_text, unused_manual_link_text,
+               relative_path_to_root):
+    """Generate a link for a @{tensorflow::...} reference."""
+    # TODO(josh11b): Fix this hard-coding of paths.
+    if string == 'tensorflow::ClientSession':
+      ret = 'class/tensorflow/client-session.md'
+    elif string == 'tensorflow::Scope':
+      ret = 'class/tensorflow/scope.md'
+    elif string == 'tensorflow::Status':
+      ret = 'class/tensorflow/status.md'
+    elif string == 'tensorflow::Tensor':
+      ret = 'class/tensorflow/tensor.md'
+    elif string == 'tensorflow::ops::Const':
+      ret = 'namespace/tensorflow/ops.md#const'
+    else:
+      log_error('Handle C++ reference "@{%s}"' % string)
+      return 'TODO_C++:%s' % string
+    # relative_path_to_root gets you to api_docs/python, we go from there
+    # to api_docs/cc, and then add ret.
+    cc_relative_path = os.path.normpath(os.path.join(
+        relative_path_to_root, '../cc', ret))
+    return '[`%s`](%s)' % (link_text, cc_relative_path)
 
 
-def _markdown_link(link_text, ref_full_name, relative_path_to_root,
-                   duplicate_of):
-  """Resolve a "@{symbol}" reference to a Markdown link, respecting duplicates.
-
-  The input to this function should already be stripped of the '@' and '{}'.
-  This function returns a Markdown link. It is assumed that this is a code
-  reference, so the link text will always be rendered as code (using backticks).
-
-  `link_text` should refer to a library symbol. You can either refer to it with
-  or without the `tf.` prefix.
+# TODO(aselle): Collect these into a big list for all modules and functions
+# and make a rosetta stone page.
+def _handle_compatibility(doc):
+  """Parse and remove compatibility blocks from the main docstring.
 
   Args:
-    link_text: The text of the Markdown link.
-    ref_full_name: The fully qualified name of the symbol to link to
-      (may optionally include 'tf.').
-    relative_path_to_root: The relative path from the location of the current
-      document to the root of the API documentation.
-    duplicate_of: A map from duplicate full names to master names.
+    doc: The docstring that contains compatibility notes.
 
   Returns:
-    A markdown link from the documentation page of `from_full_name`
-    to the documentation page of `ref_full_name`.
+    A tuple of the modified docstring and a dict mapping compatibility
+    note type to the text of the note.
   """
-  if ref_full_name.startswith('tf.'):
-    ref_full_name = ref_full_name[3:]
-
-  return '[`%s`](%s)' % (
-      link_text,
-      _reference_to_link(ref_full_name, relative_path_to_root, duplicate_of))
+  compatibility_notes = {}
+  match_compatibility = re.compile(r'[ \t]*@compatibility\((\w+)\)\s*\n'
+                                   r'((?:[^@\n]*\n)+)'
+                                   r'\s*@end_compatibility')
+  for f in match_compatibility.finditer(doc):
+    compatibility_notes[f.group(1)] = f.group(2)
+  return match_compatibility.subn('', doc)[0], compatibility_notes
 
 
-def replace_references(string, relative_path_to_root, duplicate_of):
-  """Replace "@{symbol}" references with links to symbol's documentation page.
-
-  This functions finds all occurrences of "@{symbol}" in `string` and replaces
-  them with markdown links to the documentation page for "symbol".
-
-  `relative_path_to_root` is the relative path from the document that contains
-  the "@{symbol}" reference to the root of the API documentation that is linked
-  to. If the containing page is part of the same API docset,
-  `relative_path_to_root` can be set to
-  `os.path.dirname(documentation_path(name))`, where `name` is the python name
-  of the object whose documentation page the reference lives on.
+def _gen_pairs(items):
+  """Given a list of items [a,b,a,b...], generate pairs [(a,b),(a,b)...].
 
   Args:
-    string: A string in which "@{symbol}" references should be replaced.
-    relative_path_to_root: The relative path from the contianing document to the
-      root of the API documentation that is being linked to.
-    duplicate_of: A map from duplicate names to preferred names of API symbols.
+    items: A list of items (length must be even)
+
+  Yields:
+    The original items, in pairs
+  """
+  assert len(items) % 2 == 0
+  items = iter(items)
+  for first in items:
+    # Terminate via the for loop rather than letting StopIteration escape the
+    # generator (which PEP 479 turns into a RuntimeError); the assert above
+    # guarantees next() always finds the second item of the pair.
+    yield first, next(items)
+
+
+class _FunctionDetail(
+    collections.namedtuple('_FunctionDetail', ['keyword', 'header', 'items'])):
+  """A simple class to contain function details.
+
+  Composed of a "keyword", a possibly empty "header" string, and a possibly
+  empty list of key-value pair "items".
+  """
+  __slots__ = []
+
+  def __str__(self):
+    """Return the original string that represents the function detail."""
+    parts = [self.keyword + ':\n']
+    parts.append(self.header)
+    for key, value in self.items:
+      parts.append('  ' + key + ':')
+      parts.append(value)
+
+    return ''.join(parts)
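As a quick illustration of `_handle_compatibility`, here is a minimal stand-alone sketch; the regex is the one used above, while the sample docstring and driver code around it are invented for the example:

```python
import re

match_compatibility = re.compile(r'[ \t]*@compatibility\((\w+)\)\s*\n'
                                 r'((?:[^@\n]*\n)+)'
                                 r'\s*@end_compatibility')

doc = ('Does something.\n'
       '@compatibility(numpy)\n'
       'Equivalent to np.something.\n'
       '@end_compatibility\n')

# Collect the per-type notes, then strip the blocks from the docstring.
notes = {m.group(1): m.group(2) for m in match_compatibility.finditer(doc)}
stripped = match_compatibility.subn('', doc)[0]
print(notes)     # {'numpy': 'Equivalent to np.something.\n'}
print(stripped)  # 'Does something.\n' followed by a blank line
```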
+
+
+def _parse_function_details(docstring):
+  r"""Given a docstring, split off the header and parse the function details.
+
+  For example the docstring of tf.nn.relu:
+
+  '''Computes rectified linear: `max(features, 0)`.
+
+  Args:
+    features: A `Tensor`. Must be one of the following types: `float32`,
+      `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`, `uint16`,
+      `half`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `features`.
+  '''
+
+  This is parsed, and returned as:
+
+  ```
+  ('Computes rectified linear: `max(features, 0)`.\n\n', [
+      _FunctionDetail(
+          keyword='Args',
+          header='',
+          items=[
+              ('features', ' A `Tensor`. Must be ...'),
+              ('name', ' A name for the operation (optional).\n\n')]),
+      _FunctionDetail(
+          keyword='Returns',
+          header=' A `Tensor`. Has the same type as `features`.',
+          items=[])
+  ])
+  ```
+
+  Args:
+    docstring: The docstring to parse.
+
+  Returns:
+    A (header, function_details) pair, where header is a string and
+    function_details is a (possibly empty) list of `_FunctionDetail` objects.
   """
-  full_name_re = '%s(.%s)*' % (IDENTIFIER_RE, IDENTIFIER_RE)
-  symbol_reference_re = re.compile(r'@\{(' + full_name_re + r')\}')
-  match = symbol_reference_re.search(string)
-  while match:
-    symbol_name = match.group(1)
-    link_text = _markdown_link(symbol_name, symbol_name,
-                               relative_path_to_root, duplicate_of)
-    # Remove only the '@symbol' part of the match, and replace with the link.
-    string = string[:match.start()] + link_text + string[match.end():]
-    match = symbol_reference_re.search(string,
-                                       pos=match.start() + len(link_text))
-  return string
+  detail_keywords = '|'.join([
+      'Args', 'Arguments', 'Fields', 'Returns', 'Yields', 'Raises', 'Attributes'
+  ])
+  tag_re = re.compile('(?<=\n)(' + detail_keywords + '):\n', re.MULTILINE)
+  parts = tag_re.split(docstring)
+
+  # The first part is the main docstring.
+  docstring = parts[0]
+
+  # Everything else alternates keyword-content.
+  pairs = list(_gen_pairs(parts[1:]))
+
+  function_details = []
+  item_re = re.compile(r'^  (\w+):', re.MULTILINE)
+
+  for keyword, content in pairs:
+    content = item_re.split(content)
+    header = content[0]
+    items = list(_gen_pairs(content[1:]))
+
+    function_details.append(_FunctionDetail(keyword, header, items))
+
+  return docstring, function_details
 
 
-def _md_docstring(py_object, relative_path_to_root, duplicate_of):
-  """Get the docstring from an object and make it into nice Markdown.
+_DocstringInfo = collections.namedtuple('_DocstringInfo', [
+    'brief', 'docstring', 'function_details', 'compatibility'
+])
+
+
+def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):
+  """Parse the object's docstring and return a `_DocstringInfo`.
+
+  This function clears @@'s from the docstring, and replaces @{} references
+  with markdown links.
 
   For links within the same set of docs, the `relative_path_to_root` for a
-  docstring on the page for `full_name` can be set to
+  docstring on the page for `full_name` can be set to:
 
   ```python
   relative_path_to_root = os.path.relpath(
-      os.path.dirname(documentation_path(full_name)) or '.', '.')
+      path='.', start=os.path.dirname(documentation_path(full_name)) or '.')
   ```
 
   Args:
     py_object: A python object to retrieve the docs for (class, function/method,
       or module).
     relative_path_to_root: The relative path from the location of the current
-      document to the root of the API documentation. This is used to compute
-      links for "@symbol" references.
-    duplicate_of: A map from duplicate symbol names to master names. Used to
-      resolve "@symbol" references.
+      document to the root of the Python API documentation. This is used to
+      compute links for "@{symbol}" references.
+    reference_resolver: An instance of ReferenceResolver.
 
   Returns:
-    The docstring, or the empty string if no docstring was found.
+    A `_DocstringInfo` object; all fields will be empty if no docstring was
+    found.
   """
   # TODO(wicke): If this is a partial, use the .func docstring and add a note.
raw_docstring = _get_raw_docstring(py_object) - raw_lines = raw_docstring.split('\n') - # Define regular expressions used during parsing below. - symbol_list_item_re = re.compile(r'^ (%s): ' % IDENTIFIER_RE) - section_re = re.compile(r'^(\w+):\s*$') + raw_docstring = reference_resolver.replace_references( + raw_docstring, relative_path_to_root) - # Translate docstring line by line. - in_special_section = False - lines = [] + atat_re = re.compile(r' *@@[a-zA-Z_.0-9]+ *$') + raw_docstring = '\n'.join( + line for line in raw_docstring.split('\n') if not atat_re.match(line)) - def is_section_start(i): - # Previous line is empty, line i is "Word:", and next line is indented. - return (i > 0 and not raw_lines[i-1].strip() and - re.match(section_re, raw_lines[i]) and - len(raw_lines) > i+1 and raw_lines[i+1].startswith(' ')) + docstring, compatibility = _handle_compatibility(raw_docstring) + docstring, function_details = _parse_function_details(docstring) - for i, line in enumerate(raw_lines): - if not in_special_section and is_section_start(i): - in_special_section = True - lines.append('#### ' + section_re.sub(r'\1:', line)) - lines.append('') - continue - - # If the next line starts a new section, this one ends. Add an extra line. - if in_special_section and is_section_start(i+1): - in_special_section = False - lines.append('') - - if in_special_section: - # Translate symbols in 'Args:', 'Parameters:', 'Raises:', etc. sections. - lines.append(symbol_list_item_re.sub(r'* `\1`: ', line)) - else: - lines.append(line) - - docstring = '\n'.join(lines) - - # TODO(deannarubin): Improve formatting for devsite - # TODO(deannarubin): Interpret @compatibility and other formatting notes. - - return replace_references(docstring, relative_path_to_root, duplicate_of) + return _DocstringInfo( + docstring.split('\n')[0], docstring, function_details, compatibility) def _get_arg_spec(func): """Extracts signature information from a function or functools.partial object. - For functions, uses `inspect.getargspec`. For `functools.partial` objects, + For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects, corrects the signature of the underlying function to take into account the removed arguments. @@ -246,11 +540,11 @@ def _get_arg_spec(func): Returns: An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned - by `inspect.getargspec`. + by `tf_inspect.getargspec`. """ # getargspec does not work for functools.partial objects directly. if isinstance(func, functools.partial): - argspec = inspect.getargspec(func.func) + argspec = tf_inspect.getargspec(func.func) # Remove the args from the original function that have been used up. first_default_arg = ( len(argspec.args or []) - len(argspec.defaults or [])) @@ -273,19 +567,24 @@ def _get_arg_spec(func): argspec_defaults.pop(i-first_default_arg) else: first_default_arg -= 1 - return inspect.ArgSpec(args=argspec_args, - varargs=argspec.varargs, - keywords=argspec.keywords, - defaults=tuple(argspec_defaults)) + return tf_inspect.ArgSpec(args=argspec_args, + varargs=argspec.varargs, + keywords=argspec.keywords, + defaults=tuple(argspec_defaults)) else: # Regular function or method, getargspec will work fine. - return inspect.getargspec(func) + return tf_inspect.getargspec(func) -def _generate_signature(func): - """Given a function, returns a string representing its args. 
+def _remove_first_line_indent(string):
+  indent = len(re.match(r'^\s*', string).group(0))
+  return '\n'.join([line[indent:] for line in string.split('\n')])
 
-  This function produces a string representing the arguments to a python
-  function, including surrounding parentheses. It uses inspect.getargspec, which
+
+def _generate_signature(func, reverse_index):
+  """Given a function, returns a list of strings representing its args.
+
+  This function produces a list of strings representing the arguments to a
+  python function. It uses tf_inspect.getargspec, which
   does not generalize well to Python 3.x, since Python 3 is more flexible in
   how *args and **kwargs are handled. This is not a problem in TF, since we
   have to remain compatible with Python 2.7 anyway.
@@ -297,16 +596,14 @@ def _generate_signature(func):
   document, it should be typeset as code (using backticks), or escaped.
 
   Args:
-    func: A function of method to extract the signature for (anything
-      `inspect.getargspec` will accept).
+    func: A function, method, or functools.partial to extract the signature for.
+    reverse_index: A map from object ids to canonical full names to use.
 
   Returns:
-    A string representing the signature of `func` as python code.
+    A list of strings representing the argument signature of `func` as python
+    code.
   """
-  # This produces poor signatures for decorated functions.
-  # TODO(wicke): We need to use something like the decorator module to fix it.
-
   args_list = []
 
   argspec = _get_arg_spec(func)
@@ -324,14 +621,49 @@ def _generate_signature(func):
 
   # Add all args with defaults.
   if argspec.defaults:
-    for arg, default in zip(
-        argspec.args[first_arg_with_default:], argspec.defaults):
-      # Some callables don't have __name__, fall back to including their repr.
-      # TODO(wicke): This could be improved at least for common cases.
-      if callable(default) and hasattr(default, '__name__'):
-        args_list.append('%s=%s' % (arg, default.__name__))
+    try:
+      source = _remove_first_line_indent(tf_inspect.getsource(func))
+      func_ast = ast.parse(source)
+      ast_defaults = func_ast.body[0].args.defaults
+    except IOError:  # If this is a builtin, getsource fails with IOError
+      # If we cannot get the source, assume the AST would be equal to the repr
+      # of the defaults.
+      ast_defaults = [None] * len(argspec.defaults)
+
+    for arg, default, ast_default in zip(
+        argspec.args[first_arg_with_default:], argspec.defaults, ast_defaults):
+      if id(default) in reverse_index:
+        default_text = reverse_index[id(default)]
+      elif ast_default is not None:
+        default_text = codegen.to_source(ast_default)
+        if default_text != repr(default):
+          # This may be an internal name. If so, handle the ones we know about.
+          # TODO(wicke): This should be replaced with a lookup in the index.
+          # TODO(wicke): (replace first ident with tf., check if in index)
+          internal_names = {
+              'ops.GraphKeys': 'tf.GraphKeys',
+              '_ops.GraphKeys': 'tf.GraphKeys',
+              'init_ops.zeros_initializer': 'tf.zeros_initializer',
+              'init_ops.ones_initializer': 'tf.ones_initializer',
+              'saver_pb2.SaverDef': 'tf.train.SaverDef',
+          }
+          full_name_re = '^%s(.%s)+' % (IDENTIFIER_RE, IDENTIFIER_RE)
+          match = re.match(full_name_re, default_text)
+          if match:
+            lookup_text = default_text
+            for internal_name, public_name in six.iteritems(internal_names):
+              if match.group(0).startswith(internal_name):
+                lookup_text = public_name + default_text[len(internal_name):]
+                break
+            if default_text is lookup_text:
+              # `lookup_text` still aliases `default_text` only if no internal
+              # name matched above, i.e. the lookup failed.
+              print('WARNING: Using default arg, failed lookup: %s, repr: %r' %
+                    (default_text, default))
+            else:
+              default_text = lookup_text
       else:
-        args_list.append('%s=%r' % (arg, default))
+        default_text = repr(default)
+
+      args_list.append('%s=%s' % (arg, default_text))
 
   # Add *args and *kwargs.
   if argspec.varargs:
     args_list.append('*' + argspec.varargs)
@@ -339,198 +671,624 @@ def _generate_signature(func):
   if argspec.keywords:
     args_list.append('**' + argspec.keywords)
 
-  return '(%s)' % ', '.join(args_list)
+  return args_list
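The AST-based default rendering above can be tried in isolation. A minimal sketch: `codegen.to_source` in the patch is a third-party helper; on modern Python, the standard-library `ast.unparse` (3.9+) plays the same role here:

```python
import ast
import inspect

def f(x, scope='default', keys=('a', 'b')):
  return x

# Render default values from the source AST rather than repr(), so that
# e.g. tuple literals appear exactly as written in the code.
source = inspect.getsource(f)
func_ast = ast.parse(source)
defaults = func_ast.body[0].args.defaults
print([ast.unparse(d) for d in defaults])  # ["'default'", "('a', 'b')"]
```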
-def _generate_markdown_for_function(full_name, duplicate_names,
-                                    function, duplicate_of):
-  """Generate Markdown docs for a function or method.
+def _get_guides_markdown(duplicate_names, guide_index, relative_path):
+  all_guides = []
+  for name in duplicate_names:
+    all_guides.extend(guide_index.get(name, []))
+  if not all_guides:
+    return ''
+  prefix = '../' * (relative_path.count('/') + 3)
+  links = sorted(set([guide_ref.make_md_link(prefix)
+                      for guide_ref in all_guides]))
+  return 'See the guide%s: %s\n\n' % (
+      's' if len(links) > 1 else '', ', '.join(links))
 
-  This function creates a documentation page for a function. It uses the
-  function name (incl. signature) as the title, followed by a list of duplicate
-  names (if there are any), and the Markdown formatted docstring of the
-  function.
 
-  Args:
-    full_name: The preferred name of the function. Used in the title. Must not
-      be present in `duplicate_of` (master names never are).
-    duplicate_names: A sorted list of alternative names (incl. `full_name`).
-    function: The python object referenced by `full_name`.
-    duplicate_of: A map of duplicate full names to master names. Used to resolve
-      @{symbol} references in the docstring.
+def _get_defining_class(py_class, name):
+  for cls in tf_inspect.getmro(py_class):
+    if name in cls.__dict__:
+      return cls
+  return None
 
-  Returns:
-    A string that can be written to a documentation file for this function.
+
+class _LinkInfo(
    collections.namedtuple(
        '_LinkInfo', ['short_name', 'full_name', 'obj', 'doc', 'url'])):
+
+  __slots__ = []
+
+  def is_link(self):
+    return True
+
+
+class _OtherMemberInfo(
+    collections.namedtuple('_OtherMemberInfo',
+                           ['short_name', 'full_name', 'obj', 'doc'])):
+
+  __slots__ = []
+
+  def is_link(self):
+    return False
+
+
+_PropertyInfo = collections.namedtuple(
+    '_PropertyInfo', ['short_name', 'full_name', 'obj', 'doc'])
+
+_MethodInfo = collections.namedtuple(
+    '_MethodInfo', ['short_name', 'full_name', 'obj', 'doc', 'signature'])
+
+
+class _FunctionPageInfo(object):
+  """Collects docs for a function page."""
+
+  def __init__(self, full_name):
+    self._full_name = full_name
+    self._defined_in = None
+    self._aliases = None
+    self._doc = None
+    self._guides = None
+
+    self._signature = None
+
+  def for_function(self):
+    return True
+
+  def for_class(self):
+    return False
+
+  def for_module(self):
+    return False
+
+  @property
+  def full_name(self):
+    return self._full_name
+
+  @property
+  def short_name(self):
+    return self._full_name.split('.')[-1]
+
+  @property
+  def defined_in(self):
+    return self._defined_in
+
+  def set_defined_in(self, defined_in):
+    assert self.defined_in is None
+    self._defined_in = defined_in
+
+  @property
+  def aliases(self):
+    return self._aliases
+
+  def set_aliases(self, aliases):
+    assert self.aliases is None
+    self._aliases = aliases
+
+  @property
+  def doc(self):
+    return self._doc
+
+  def set_doc(self, doc):
+    assert self.doc is None
+    self._doc = doc
+
+  @property
+  def guides(self):
+    return self._guides
+
+  def set_guides(self, guides):
+    assert self.guides is None
+    self._guides = guides
+
+  @property
+  def signature(self):
+    return self._signature
+
+  def set_signature(self, function, reverse_index):
+    """Attach the function's signature.
+
+    Args:
+      function: The python function being documented.
+      reverse_index: A map from object ids in the index to full names.
+    """
+
+    assert self.signature is None
+    self._signature = _generate_signature(function, reverse_index)
+
+
+class _ClassPageInfo(object):
+  """Collects docs for a class page.
+
+  Attributes:
+    full_name: The fully qualified name of the object at the master
+      location. Aka `master_name`. For example: `tf.nn.sigmoid`.
+    short_name: The last component of the `full_name`. For example: `sigmoid`.
+    defined_in: The path to the file where this object is defined.
+    aliases: The list of all fully qualified names for the locations where the
+      object is visible in the public api. This includes the master location.
+    doc: A `_DocstringInfo` object representing the object's docstring (can be
+      created with `_parse_md_docstring`).
+    guides: A markdown string of back links pointing to the api_guides that
+      reference this object.
+    bases: A list of `_LinkInfo` objects pointing to the docs for the parent
+      classes.
+    properties: A list of `_PropertyInfo` objects documenting the class'
+      properties (attributes that use `@property`).
+    methods: A list of `_MethodInfo` objects documenting the class' methods.
+    classes: A list of `_LinkInfo` objects pointing to docs for any nested
+      classes.
+    other_members: A list of `_OtherMemberInfo` objects documenting any other
+      objects defined inside the class object (mostly enum style fields).
   """
-  # TODO(wicke): Make sure this works for partials.
-  relative_path = os.path.relpath(
-      os.path.dirname(documentation_path(full_name)) or '.', '.')
-  docstring = _md_docstring(function, relative_path, duplicate_of)
-  signature = _generate_signature(function)
-  if duplicate_names:
-    aliases = '\n'.join(['### `%s`' % (name + signature)
-                         for name in duplicate_names])
-    aliases += '\n\n'
-  else:
-    aliases = ''
+  def __init__(self, full_name):
+    self._full_name = full_name
+    self._defined_in = None
+    self._aliases = None
+    self._doc = None
+    self._guides = None
 
-  return '#`%s%s`\n\n%s%s' % (full_name, signature, aliases, docstring)
+    self._bases = None
+    self._properties = []
+    self._methods = []
+    self._classes = []
+    self._other_members = []
+
+  def for_function(self):
+    """Returns true if this object documents a function."""
+    return False
+
+  def for_class(self):
+    """Returns true if this object documents a class."""
+    return True
+
+  def for_module(self):
+    """Returns true if this object documents a module."""
+    return False
+
+  @property
+  def full_name(self):
+    """Returns the documented object's fully qualified name."""
+    return self._full_name
+
+  @property
+  def short_name(self):
+    """Returns the documented object's short name."""
+    return self._full_name.split('.')[-1]
+
+  @property
+  def defined_in(self):
+    """Returns the path to the file where the documented object is defined."""
+    return self._defined_in
+
+  def set_defined_in(self, defined_in):
+    """Sets the `defined_in` path."""
+    assert self.defined_in is None
+    self._defined_in = defined_in
+
+  @property
+  def aliases(self):
+    """Returns a list of all full names for the documented object."""
+    return self._aliases
+
+  def set_aliases(self, aliases):
+    """Sets the `aliases` list.
+
+    Args:
+      aliases: A list of strings containing all of the object's full names.
+    """
+    assert self.aliases is None
+    self._aliases = aliases
+
+  @property
+  def doc(self):
+    """Returns a `_DocstringInfo` created from the object's docstring."""
+    return self._doc
+
+  def set_doc(self, doc):
+    """Sets the `doc` field.
+
+    Args:
+      doc: An instance of `_DocstringInfo`.
+    """
+    assert self.doc is None
+    self._doc = doc
+
+  @property
+  def guides(self):
+    """Returns a markdown string containing backlinks to relevant api_guides."""
+    return self._guides
+
+  def set_guides(self, guides):
+    """Sets the `guides` field.
+
+    Args:
+      guides: A markdown string containing backlinks to all the api_guides that
+        link to the documented object.
+    """
+    assert self.guides is None
+    self._guides = guides
+
+  @property
+  def bases(self):
+    """Returns a list of `_LinkInfo` objects pointing to the class' parents."""
+    return self._bases
+
+  def _set_bases(self, relative_path, parser_config):
+    """Builds the `bases` attribute, to document this class' parent-classes.
+
+    This method sets `bases` to a list of `_LinkInfo` objects pointing to the
+    doc pages for the class' parents.
+
+    Args:
+      relative_path: The relative path from the doc this object describes to
+        the documentation root.
+      parser_config: An instance of `ParserConfig`.
+ """ + bases = [] + obj = parser_config.py_name_to_object(self.full_name) + for base in obj.__bases__: + base_full_name = parser_config.reverse_index.get(id(base), None) + if base_full_name is None: + continue + base_doc = _parse_md_docstring(base, relative_path, + parser_config.reference_resolver) + base_url = parser_config.reference_resolver.reference_to_url( + base_full_name, relative_path) + + link_info = _LinkInfo(short_name=base_full_name.split('.')[-1], + full_name=base_full_name, obj=base, + doc=base_doc, url=base_url) + bases.append(link_info) + + self._bases = bases + + @property + def properties(self): + """Returns a list of `_PropertyInfo` describing the class' properties.""" + return self._properties + + def _add_property(self, short_name, full_name, obj, doc): + """Adds a `_PropertyInfo` entry to the `properties` list. + + Args: + short_name: The property's short name. + full_name: The property's fully qualified name. + obj: The property object itself + doc: The property's parsed docstring, a `_DocstringInfo`. + """ + property_info = _PropertyInfo(short_name, full_name, obj, doc) + self._properties.append(property_info) + + @property + def methods(self): + """Returns a list of `_MethodInfo` describing the class' methods.""" + return self._methods + + def _add_method(self, short_name, full_name, obj, doc, signature): + """Adds a `_MethodInfo` entry to the `methods` list. + + Args: + short_name: The method's short name. + full_name: The method's fully qualified name. + obj: The method object itself + doc: The method's parsed docstring, a `_DocstringInfo` + signature: The method's parsed signature (see: `_generate_signature`) + """ + method_info = _MethodInfo(short_name, full_name, obj, doc, signature) + self._methods.append(method_info) + + @property + def classes(self): + """Returns a list of `_LinkInfo` pointing to any nested classes.""" + return self._classes + + def _add_class(self, short_name, full_name, obj, doc, url): + """Adds a `_LinkInfo` for a nested class to `classes` list. + + Args: + short_name: The class' short name. + full_name: The class' fully qualified name. + obj: The class object itself + doc: The class' parsed docstring, a `_DocstringInfo` + url: A url pointing to where the nested class is documented. + """ + page_info = _LinkInfo(short_name, full_name, obj, doc, url) + + self._classes.append(page_info) + + @property + def other_members(self): + """Returns a list of `_OtherMemberInfo` describing any other contents.""" + return self._other_members + + def _add_other_member(self, short_name, full_name, obj, doc): + """Adds an `_OtherMemberInfo` entry to the `other_members` list. + + Args: + short_name: The class' short name. + full_name: The class' fully qualified name. + obj: The class object itself + doc: The class' parsed docstring, a `_DocstringInfo` + """ + other_member_info = _OtherMemberInfo(short_name, full_name, obj, doc) + self._other_members.append(other_member_info) + + def collect_docs_for_class(self, py_class, parser_config): + """Collects information necessary specifically for a class's doc page. + + Mainly, this is details about the class's members. + + Args: + py_class: The class object being documented + parser_config: An instance of ParserConfig. 
+ """ + doc_path = documentation_path(self.full_name) + relative_path = os.path.relpath( + path='.', start=os.path.dirname(doc_path) or '.') + + self._set_bases(relative_path, parser_config) + + for short_name in parser_config.tree[self.full_name]: + # Remove builtin members that we never want to document. + if short_name in ['__class__', '__base__', '__weakref__', '__doc__', + '__module__', '__dict__', '__abstractmethods__', + '__slots__', '__getnewargs__']: + continue + + child_name = '.'.join([self.full_name, short_name]) + child = parser_config.py_name_to_object(child_name) + + # Don't document anything that is defined in object or by protobuf. + defining_class = _get_defining_class(py_class, short_name) + if (defining_class is object or + defining_class is type or defining_class is tuple or + defining_class is BaseException or defining_class is Exception or + # The following condition excludes most protobuf-defined symbols. + defining_class and defining_class.__name__ in ['CMessage', 'Message', + 'MessageMeta']): + continue + # TODO(markdaoust): Add a note in child docs showing the defining class. + + child_doc = _parse_md_docstring(child, relative_path, + parser_config.reference_resolver) + + if isinstance(child, property): + self._add_property(short_name, child_name, child, child_doc) + + elif tf_inspect.isclass(child): + if defining_class is None: + continue + url = parser_config.reference_resolver.reference_to_url( + child_name, relative_path) + self._add_class(short_name, child_name, child, child_doc, url) + + elif (tf_inspect.ismethod(child) or tf_inspect.isfunction(child) or + tf_inspect.isroutine(child)): + if defining_class is None: + continue + + # Omit methods defined by namedtuple. + original_method = defining_class.__dict__[short_name] + if (hasattr(original_method, '__module__') and + (original_method.__module__ or '').startswith('namedtuple')): + continue + + # Some methods are often overridden without documentation. Because it's + # obvious what they do, don't include them in the docs if there's no + # docstring. + if not child_doc.brief.strip() and short_name in [ + '__str__', '__repr__', '__hash__', '__del__', '__copy__']: + print('Skipping %s, defined in %s, no docstring.' % (child_name, + defining_class)) + continue + + try: + child_signature = _generate_signature(child, + parser_config.reverse_index) + except TypeError: + # If this is a (dynamically created) slot wrapper, tf_inspect will + # raise typeerror when trying to get to the code. Ignore such + # functions. + continue + + self._add_method(short_name, child_name, child, child_doc, + child_signature) + else: + # Exclude members defined by protobuf that are useless + if issubclass(py_class, ProtoMessage): + if (short_name.endswith('_FIELD_NUMBER') or + short_name in ['__slots__', 'DESCRIPTOR']): + continue + + # TODO(wicke): We may want to also remember the object itself. + self._add_other_member(short_name, child_name, child, child_doc) -def _generate_markdown_for_class(full_name, duplicate_names, py_class, - duplicate_of, index, tree): - """Generate Markdown docs for a class. +class _ModulePageInfo(object): + """Collects docs for a module page.""" - This function creates a documentation page for a class. It uses the - class name as the title, followed by a list of duplicate - names (if there are any), the Markdown formatted docstring of the - class, a list of links to all child class docs, a list of all properties - including their docstrings, a list of all methods incl. 
their docstrings, and - a list of all class member names (public fields). + def __init__(self, full_name): + self._full_name = full_name + self._defined_in = None + self._aliases = None + self._doc = None + self._guides = None - Args: - full_name: The preferred name of the class. Used in the title. Must not - be present in `duplicate_of` (master names never are). - duplicate_names: A sorted list of alternative names (incl. `full_name`). - py_class: The python object referenced by `full_name`. - duplicate_of: A map of duplicate full names to master names. Used to resolve - @{symbol} references in the docstrings. - index: A map from full names to python object references. - tree: A map from full names to the names of all documentable child objects. + self._modules = [] + self._classes = [] + self._functions = [] + self._other_members = [] - Returns: - A string that can be written to a documentation file for this class. - """ - relative_path = os.path.relpath( - os.path.dirname(documentation_path(full_name)) or '.', '.') - docstring = _md_docstring(py_class, relative_path, duplicate_of) - if duplicate_names: - aliases = '\n'.join(['### `class %s`' % name for name in duplicate_names]) - aliases += '\n\n' - else: - aliases = '' + def for_function(self): + return False - docs = '# `%s`\n\n%s%s\n\n' % (full_name, aliases, docstring) + def for_class(self): + return False - field_names = [] - properties = [] - methods = [] - class_links = [] - for member in tree[full_name]: - child_name = '.'.join([full_name, member]) - child = index[child_name] + def for_module(self): + return True - if isinstance(child, property): - properties.append((member, child)) - elif inspect.isclass(child): - class_links.append(_markdown_link('class ' + member, child_name, - relative_path, duplicate_of)) - elif inspect.ismethod(child) or inspect.isfunction(child): - methods.append((member, child)) - else: - # TODO(wicke): We may want to also remember the object itself. - field_names.append(member) + @property + def full_name(self): + return self._full_name - if class_links: - docs += '## Child Classes\n' - docs += '\n\n'.join(sorted(class_links)) - docs += '\n\n' + @property + def short_name(self): + return self._full_name.split('.')[-1] - if properties: - docs += '## Properties\n\n' - for property_name, prop in sorted(properties, key=lambda x: x[0]): - docs += '### `%s`\n\n%s\n\n' % ( - property_name, _md_docstring(prop, relative_path, duplicate_of)) - docs += '\n\n' + @property + def defined_in(self): + return self._defined_in - if methods: - docs += '## Methods\n\n' - for method_name, method in sorted(methods, key=lambda x: x[0]): - method_signature = method_name + _generate_signature(method) - docs += '### `%s`\n\n%s\n\n' % (method_signature, - _md_docstring(method, relative_path, - duplicate_of)) - docs += '\n\n' + def set_defined_in(self, defined_in): + assert self.defined_in is None + self._defined_in = defined_in - if field_names: - docs += '## Class Members\n\n' - # TODO(wicke): Document the value of the members, at least for basic types. 
- docs += '\n\n'.join(sorted(field_names)) - docs += '\n\n' + @property + def aliases(self): + return self._aliases - return docs + def set_aliases(self, aliases): + assert self.aliases is None + self._aliases = aliases + + @property + def doc(self): + return self._doc + + def set_doc(self, doc): + assert self.doc is None + self._doc = doc + + @property + def guides(self): + return self._guides + + def set_guides(self, guides): + assert self.guides is None + self._guides = guides + + @property + def modules(self): + return self._modules + + def _add_module(self, short_name, full_name, obj, doc, url): + self._modules.append(_LinkInfo(short_name, full_name, obj, doc, url)) + + @property + def classes(self): + return self._classes + + def _add_class(self, short_name, full_name, obj, doc, url): + self._classes.append(_LinkInfo(short_name, full_name, obj, doc, url)) + + @property + def functions(self): + return self._functions + + def _add_function(self, short_name, full_name, obj, doc, url): + self._functions.append(_LinkInfo(short_name, full_name, obj, doc, url)) + + @property + def other_members(self): + return self._other_members + + def _add_other_member(self, short_name, full_name, obj, doc): + self._other_members.append( + _OtherMemberInfo(short_name, full_name, obj, doc)) + + def collect_docs_for_module(self, parser_config): + """Collect information necessary specifically for a module's doc page. + + Mainly this is information about the members of the module. + + Args: + parser_config: An instance of ParserConfig. + """ + relative_path = os.path.relpath( + path='.', + start=os.path.dirname(documentation_path(self.full_name)) or '.') + + member_names = parser_config.tree.get(self.full_name, []) + for name in member_names: + + if name in ['__builtins__', '__doc__', '__file__', + '__name__', '__path__', '__package__']: + continue + + member_full_name = self.full_name + '.' + name if self.full_name else name + member = parser_config.py_name_to_object(member_full_name) + + member_doc = _parse_md_docstring(member, relative_path, + parser_config.reference_resolver) + + url = parser_config.reference_resolver.reference_to_url( + member_full_name, relative_path) + + if tf_inspect.ismodule(member): + self._add_module(name, member_full_name, member, member_doc, url) + + elif tf_inspect.isclass(member): + self._add_class(name, member_full_name, member, member_doc, url) + + elif tf_inspect.isfunction(member): + self._add_function(name, member_full_name, member, member_doc, url) + + else: + self._add_other_member(name, member_full_name, member, member_doc) -def _generate_markdown_for_module(full_name, duplicate_names, module, - duplicate_of, index, tree): - """Generate Markdown docs for a module. +class ParserConfig(object): + """Stores all indexes required to parse the docs.""" - This function creates a documentation page for a module. It uses the - module name as the title, followed by a list of duplicate - names (if there are any), the Markdown formatted docstring of the - class, and a list of links to all members of this module. + def __init__(self, reference_resolver, duplicates, duplicate_of, tree, index, + reverse_index, guide_index, base_dir): + """Object with the common config for docs_for_object() calls. - Args: - full_name: The preferred name of the module. Used in the title. Must not - be present in `duplicate_of` (master names never are). - duplicate_names: A sorted list of alternative names (incl. `full_name`). - module: The python object referenced by `full_name`. 
- duplicate_of: A map of duplicate full names to master names. Used to resolve - @{symbol} references in the docstrings. - index: A map from full names to python object references. - tree: A map from full names to the names of all documentable child objects. + Args: + reference_resolver: An instance of ReferenceResolver. + duplicates: A `dict` mapping fully qualified names to a set of all + aliases of this name. This is used to automatically generate a list of + all aliases for each name. + duplicate_of: A map from duplicate names to preferred names of API + symbols. + tree: A `dict` mapping a fully qualified name to the names of all its + members. Used to populate the members section of a class or module page. + index: A `dict` mapping full names to objects. + reverse_index: A `dict` mapping object ids to full names. - Returns: - A string that can be written to a documentation file for this module. - """ - relative_path = os.path.relpath( - os.path.dirname(documentation_path(full_name)) or '.', '.') - docstring = _md_docstring(module, relative_path, duplicate_of) - if duplicate_names: - aliases = '\n'.join(['### Module `%s`' % name for name in duplicate_names]) - aliases += '\n\n' - else: - aliases = '' + guide_index: A `dict` mapping symbol name strings to objects with a + `make_md_link()` method. - member_names = tree.get(full_name, []) + base_dir: A base path that is stripped from file locations written to the + docs. + """ + self.reference_resolver = reference_resolver + self.duplicates = duplicates + self.duplicate_of = duplicate_of + self.tree = tree + self.reverse_index = reverse_index + self.index = index + self.guide_index = guide_index + self.base_dir = base_dir + self.defined_in_prefix = 'tensorflow/' + self.code_url_prefix = ( + 'https://www.tensorflow.org/code/tensorflow/') # pylint: disable=line-too-long - # Make links to all members. - member_links = [] - for name in member_names: - member_full_name = full_name + '.' + name if full_name else name - member = index[member_full_name] - - if inspect.isclass(member): - link_text = 'class ' + name - elif inspect.isfunction(member): - link_text = name + _generate_signature(member) - else: - link_text = name - - member_links.append(_markdown_link(link_text, member_full_name, - relative_path, duplicate_of)) - - # TODO(deannarubin): Make this list into a table and add the brief docstring. - # (use _get_brief_docstring) - - return '# Module `%s`\n\n%s%s\n\n## Members\n\n%s' % ( - full_name, aliases, docstring, '\n\n'.join(member_links)) + def py_name_to_object(self, full_name): + """Return the Python object for a Python symbol name.""" + return self.index[full_name] -_CODE_URL_PREFIX = ( - 'https://www.tensorflow.org/code/') +def docs_for_object(full_name, py_object, parser_config): + """Return a PageInfo object describing a given object from the TF API. - -def generate_markdown(full_name, py_object, - duplicate_of, duplicates, - index, tree, base_dir): - """Generate Markdown docs for a given object that's part of the TF API. - - This function uses _md_docstring to obtain the docs pertaining to + This function uses _parse_md_docstring to parse the docs pertaining to `object`. - This function resolves '@symbol' references in the docstrings into links to + This function resolves '@{symbol}' references in the docstrings into links to the appropriate location. It also adds a list of alternative names for the symbol automatically. 
@@ -538,28 +1296,16 @@ def generate_markdown(full_name, py_object, `documentation_path`, and that relative links to files within the documentation are resolvable. - The output is Markdown that can be written to file and published. - Args: - full_name: The fully qualified name (excl. "tf.") of the symbol to be + full_name: The fully qualified name of the symbol to be documented. py_object: The Python object to be documented. Its documentation is sourced from `py_object`'s docstring. - duplicate_of: A `dict` mapping fully qualified names to "master" names. This - is used to resolve "@{symbol}" references to the "master" name. - duplicates: A `dict` mapping fully qualified names to a set of all - aliases of this name. This is used to automatically generate a list of all - aliases for each name. - index: A `dict` mapping fully qualified names to the corresponding Python - objects. Used to produce docs for child objects, and to check the validity - of "@{symbol}" references. - tree: A `dict` mapping a fully qualified name to the names of all its - members. Used to populate the members section of a class or module page. - base_dir: A base path that is stripped from file locations written to the - docs. + parser_config: A ParserConfig object. Returns: - A string containing the Markdown docs for `py_object`. + Either a `_FunctionPageInfo`, `_ClassPageInfo`, or a `_ModulePageInfo` + depending on the type of the python object being documented. Raises: RuntimeError: If an object is encountered for which we don't know how @@ -567,47 +1313,183 @@ def generate_markdown(full_name, py_object, """ # Which other aliases exist for the object referenced by full_name? - master_name = duplicate_of.get(full_name, full_name) - duplicate_names = duplicates.get(master_name, [full_name]) + master_name = parser_config.reference_resolver.py_master_name(full_name) + duplicate_names = parser_config.duplicates.get(master_name, [full_name]) # TODO(wicke): Once other pieces are ready, enable this also for partials. - if (inspect.ismethod(py_object) or inspect.isfunction(py_object) or + if (tf_inspect.ismethod(py_object) or tf_inspect.isfunction(py_object) or # Some methods in classes from extensions come in as routines. 
-      inspect.isroutine(py_object)):
-    markdown = _generate_markdown_for_function(master_name, duplicate_names,
-                                               py_object, duplicate_of)
-  elif inspect.isclass(py_object):
-    markdown = _generate_markdown_for_class(master_name, duplicate_names,
-                                            py_object, duplicate_of,
-                                            index, tree)
-  elif inspect.ismodule(py_object):
-    markdown = _generate_markdown_for_module(master_name, duplicate_names,
-                                             py_object, duplicate_of,
-                                             index, tree)
+      tf_inspect.isroutine(py_object)):
+    page_info = _FunctionPageInfo(master_name)
+    page_info.set_signature(py_object, parser_config.reverse_index)
+
+  elif tf_inspect.isclass(py_object):
+    page_info = _ClassPageInfo(master_name)
+    page_info.collect_docs_for_class(py_object, parser_config)
+
+  elif tf_inspect.ismodule(py_object):
+    page_info = _ModulePageInfo(master_name)
+    page_info.collect_docs_for_module(parser_config)
+
   else:
     raise RuntimeError('Cannot make docs for object %s: %r' % (full_name,
                                                                py_object))
 
-  # Every page gets a note on the bottom about where this object is defined
+  relative_path = os.path.relpath(
+      path='.', start=os.path.dirname(documentation_path(full_name)) or '.')
+
+  page_info.set_doc(_parse_md_docstring(
+      py_object, relative_path, parser_config.reference_resolver))
+
+  page_info.set_aliases(duplicate_names)
+
+  page_info.set_guides(_get_guides_markdown(
+      duplicate_names, parser_config.guide_index, relative_path))
+
+  page_info.set_defined_in(_get_defined_in(py_object, parser_config))
+
+  return page_info
+
+
+class _PythonBuiltin(object):
+  """This class indicates that the object in question is a python builtin.
+
+  This can be used for the `defined_in` slot of the `PageInfo` objects.
+  """
+
+  def is_builtin(self):
+    return True
+
+  def is_python_file(self):
+    return False
+
+  def is_generated_file(self):
+    return False
+
+  def __str__(self):
+    return 'This is an alias for a Python built-in.\n\n'
+
+
+class _PythonFile(object):
+  """This class indicates that the object is defined in a regular python file.
+
+  This can be used for the `defined_in` slot of the `PageInfo` objects.
+  """
+
+  def __init__(self, path, parser_config):
+    self.path = path
+    self.path_prefix = parser_config.defined_in_prefix
+    self.code_url_prefix = parser_config.code_url_prefix
+
+  def is_builtin(self):
+    return False
+
+  def is_python_file(self):
+    return True
+
+  def is_generated_file(self):
+    return False
+
+  def __str__(self):
+    return 'Defined in [`{prefix}{path}`]({code_prefix}{path}).\n\n'.format(
+        path=self.path, prefix=self.path_prefix,
+        code_prefix=self.code_url_prefix)
+
+
+class _ProtoFile(object):
+  """This class indicates that the object is defined in a .proto file.
+
+  This can be used for the `defined_in` slot of the `PageInfo` objects.
+  """
+
+  def __init__(self, path, parser_config):
+    self.path = path
+    self.path_prefix = parser_config.defined_in_prefix
+    self.code_url_prefix = parser_config.code_url_prefix
+
+  def is_builtin(self):
+    return False
+
+  def is_python_file(self):
+    return False
+
+  def is_generated_file(self):
+    return False
+
+  def __str__(self):
+    return 'Defined in [`{prefix}{path}`]({code_prefix}{path}).\n\n'.format(
+        path=self.path, prefix=self.path_prefix,
+        code_prefix=self.code_url_prefix)
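For orientation, this is roughly the note that `_PythonFile.__str__` renders; the path is a made-up example, and the prefix strings mirror the `ParserConfig` defaults set earlier in this patch:

```python
# Hypothetical example values; `defined_in_prefix` and `code_url_prefix`
# default to these strings in ParserConfig above.
path = 'python/ops/nn.py'
prefix = 'tensorflow/'
code_prefix = 'https://www.tensorflow.org/code/tensorflow/'

print('Defined in [`{prefix}{path}`]({code_prefix}{path}).\n\n'.format(
    path=path, prefix=prefix, code_prefix=code_prefix))
# Defined in [`tensorflow/python/ops/nn.py`](https://www.tensorflow.org/code/tensorflow/python/ops/nn.py).
```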
+ """ + + def __init__(self, path, parser_config): + self.path = path + self.path_prefix = parser_config.defined_in_prefix + + def is_builtin(self): + return False + + def is_python_file(self): + return False + + def is_generated_file(self): + return True + + def __str__(self): + return 'Defined in `%s%s`.\n\n' % (self.path_prefix, self.path) + + +def _get_defined_in(py_object, parser_config): + """Returns a description of where the passed in python object was defined. + + Arguments: + py_object: The Python object. + parser_config: A ParserConfig object. + + Returns: + Either a `_PythonBuiltin`, `_PythonFile`, or a `_GeneratedFile` + """ + # Every page gets a note about where this object is defined # TODO(wicke): If py_object is decorated, get the decorated object instead. # TODO(wicke): Only use decorators that support this in TF. try: - path = os.path.relpath(inspect.getfile(py_object), base_dir) - - # TODO(wicke): If this is a generated file, point to the source instead. - - # Never include links outside this code base. - if not path.startswith('..'): - markdown += '\n\nDefined in [`%s`](%s%s).\n\n' % ( - path, _CODE_URL_PREFIX, path) + path = os.path.relpath(path=tf_inspect.getfile(py_object), + start=parser_config.base_dir) except TypeError: # getfile throws TypeError if py_object is a builtin. - markdown += '\n\nThis is an alias for a Python built-in.' + return _PythonBuiltin() - return markdown + # TODO(wicke): If this is a generated file, link to the source instead. + # TODO(wicke): Move all generated files to a generated/ directory. + # TODO(wicke): And make their source file predictable from the file name. + + # In case this is compiled, point to the original + if path.endswith('.pyc'): + path = path[:-1] + + # Never include links outside this code base. + if path.startswith('..'): + return None + + if re.match(r'.*/gen_[^/]*\.py$', path): + return _GeneratedFile(path, parser_config) + elif re.match(r'.*_pb2\.py$', path): + # The _pb2.py files all appear right next to their defining .proto file. + return _ProtoFile(path[:-7] + '.proto', parser_config) + else: + return _PythonFile(path, parser_config) -def generate_global_index(library_name, root_name, index, duplicate_of): +# TODO(markdaoust): This should just parse, pretty_docs should generate the md. +def generate_global_index(library_name, index, reference_resolver): """Given a dict of full names to python objects, generate an index page. The index page generated contains a list of links for all symbols in `index` @@ -615,37 +1497,31 @@ def generate_global_index(library_name, root_name, index, duplicate_of): Args: library_name: The name for the documented library to use in the title. - root_name: The name to use for the root module. index: A dict mapping full names to python objects. - duplicate_of: A map of duplicate names to preferred names. + reference_resolver: An instance of ReferenceResolver. Returns: A string containing an index page as Markdown. """ symbol_links = [] for full_name, py_object in six.iteritems(index): - index_name = full_name or root_name - if (inspect.ismodule(py_object) or inspect.isfunction(py_object) or - inspect.isclass(py_object)): + if (tf_inspect.ismodule(py_object) or tf_inspect.isfunction(py_object) or + tf_inspect.isclass(py_object)): # In Python 3, unbound methods are functions, so eliminate those. 
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 521e2d4ed3b..3e02160130f 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -19,19 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
-import inspect
 import os
 import sys
 
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.docs import parser
 
 
-def test_function_for_markdown_reference(unused_arg):
-  """Docstring with reference to @{test_function}."""
-  pass
-
-
 def test_function(unused_arg, unused_kwarg='default'):
   """Docstring for test function."""
   pass
@@ -42,19 +37,6 @@ def test_function_with_args_kwargs(unused_arg, *unused_args, **unused_kwargs):
   pass
 
 
-def test_function_with_fancy_docstring(arg):
-  """Function with a fancy docstring.
-
-  Args:
-    arg: An argument.
-
-  Returns:
-    arg: the input, and
-    arg: the input, again.
-  """
-  return arg, arg
-
-
 class TestClass(object):
   """Docstring for TestClass itself."""
 
@@ -74,26 +56,70 @@ class TestClass(object):
   CLASS_MEMBER = 'a class member'
 
 
+class DummyVisitor(object):
+
+  def __init__(self, index, duplicate_of):
+    self.index = index
+    self.duplicate_of = duplicate_of
+
+
 class ParserTest(googletest.TestCase):
 
   def test_documentation_path(self):
     self.assertEqual('test.md', parser.documentation_path('test'))
     self.assertEqual('test/module.md', parser.documentation_path('test.module'))
 
-  def test_documentation_path_empty(self):
-    self.assertEqual('index.md', parser.documentation_path(''))
-
   def test_replace_references(self):
-    string = 'A @{reference}, another @{tf.reference}, and a @{third}.'
- duplicate_of = {'third': 'fourth'} - result = parser.replace_references(string, '../..', duplicate_of) + class HasOneMember(object): + + def foo(self): + pass + + string = ('A @{tf.reference}, another @{tf.reference}, ' + 'a member @{tf.reference.foo}, and a @{tf.third}.') + duplicate_of = {'tf.third': 'tf.fourth'} + index = {'tf.reference': HasOneMember, + 'tf.reference.foo': HasOneMember.foo, + 'tf.third': HasOneMember, + 'tf.fourth': HasOneMember} + + visitor = DummyVisitor(index, duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + result = reference_resolver.replace_references(string, '../..') self.assertEqual( - 'A [`reference`](../../reference.md), another ' - '[`tf.reference`](../../reference.md), ' - 'and a [`third`](../../fourth.md).', + 'A [`tf.reference`](../../tf/reference.md), another ' + '[`tf.reference`](../../tf/reference.md), ' + 'a member [`tf.reference.foo`](../../tf/reference.md#foo), ' + 'and a [`tf.third`](../../tf/fourth.md).', result) - def test_generate_markdown_for_class(self): + def test_doc_replace_references(self): + string = '@{$doc1} @{$doc1#abc} @{$doc1$link} @{$doc1#def$zelda} @{$do/c2}' + + class DocInfo(object): + pass + doc1 = DocInfo() + doc1.title = 'Title1' + doc1.url = 'URL1' + doc2 = DocInfo() + doc2.title = 'Two words' + doc2.url = 'somewhere/else' + doc_index = {'doc1': doc1, 'do/c2': doc2} + + visitor = DummyVisitor(index={}, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index=doc_index, py_module_names=['tf']) + result = reference_resolver.replace_references(string, 'python') + self.assertEqual( + '[Title1](../URL1) [Title1](../URL1#abc) [link](../URL1) ' + '[zelda](../URL1#def) [Two words](../somewhere/else)', + result) + + def test_docs_for_class(self): index = { 'TestClass': TestClass, @@ -103,32 +129,48 @@ class ParserTest(googletest.TestCase): 'TestClass.CLASS_MEMBER': TestClass.CLASS_MEMBER } + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + tree = { 'TestClass': ['a_method', 'a_property', 'ChildClass', 'CLASS_MEMBER'] } + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') - docs = parser.generate_markdown(full_name='TestClass', py_object=TestClass, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') + page_info = parser.docs_for_object( + full_name='TestClass', py_object=TestClass, parser_config=parser_config) - # Make sure all required docstrings are present. - self.assertTrue(inspect.getdoc(TestClass) in docs) - self.assertTrue(inspect.getdoc(TestClass.a_method) in docs) - self.assertTrue(inspect.getdoc(TestClass.a_property) in docs) + # Make sure the brief docstring is present + self.assertEqual( + tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief) + + # Make sure the method is present + self.assertEqual(TestClass.a_method, page_info.methods[0].obj) # Make sure that the signature is extracted properly and omits self. 
- self.assertTrue('a_method(arg=\'default\')' in docs) + self.assertEqual(["arg='default'"], page_info.methods[0].signature) + + # Make sure the property is present + self.assertIs(TestClass.a_property, page_info.properties[0].obj) # Make sure there is a link to the child class and it points the right way. - self.assertTrue('[`class ChildClass`](./TestClass/ChildClass.md)' in docs) - - # Make sure CLASS_MEMBER is mentioned. - self.assertTrue('CLASS_MEMBER' in docs) + self.assertIs(TestClass.ChildClass, page_info.classes[0].obj) # Make sure this file is contained as the definition location. - self.assertTrue(os.path.relpath(__file__, '/') in docs) + self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) - def test_generate_markdown_for_module(self): + def test_docs_for_module(self): + # Get the current module. module = sys.modules[__name__] index = { @@ -139,127 +181,180 @@ class ParserTest(googletest.TestCase): 'TestModule.TestClass': TestClass, } + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + tree = { 'TestModule': ['TestClass', 'test_function', 'test_function_with_args_kwargs'] } + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') - docs = parser.generate_markdown(full_name='TestModule', py_object=module, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') + page_info = parser.docs_for_object( + full_name='TestModule', py_object=module, parser_config=parser_config) - # Make sure all required docstrings are present. - self.assertTrue(inspect.getdoc(module) in docs) + # Make sure the brief docstring is present + self.assertEqual(tf_inspect.getdoc(module).split('\n')[0], + page_info.doc.brief) - # Make sure that links to the members are there (not asserting on exact link - # text for functions). - self.assertTrue('./TestModule/test_function.md' in docs) - self.assertTrue('./TestModule/test_function_with_args_kwargs.md' in docs) + # Make sure that the members are there + funcs = {f_info.obj for f_info in page_info.functions} + self.assertEqual({test_function, test_function_with_args_kwargs}, funcs) - # Make sure there is a link to the child class and it points the right way. - self.assertTrue('[`class TestClass`](./TestModule/TestClass.md)' in docs) + classes = {cls_info.obj for cls_info in page_info.classes} + self.assertEqual({TestClass}, classes) # Make sure this file is contained as the definition location. 
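Each page-info also records where the documented object was defined, via `page_info.defined_in`; for these tests that is this very test file, relative to the `base_dir='/'` given to `ParserConfig`. A short sketch mirroring the assertion that follows (no API beyond what the diff itself shows):

```python
import os

# defined_in.path is the documented object's source file, relative to
# parser_config.base_dir ('/' in these tests).
assert page_info.defined_in.path == os.path.relpath(__file__, '/')
```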
- self.assertTrue(os.path.relpath(__file__, '/') in docs) + self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) - def test_generate_markdown_for_function(self): + def test_docs_for_function(self): index = { 'test_function': test_function } + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + tree = { '': ['test_function'] } + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') - docs = parser.generate_markdown(full_name='test_function', - py_object=test_function, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') + page_info = parser.docs_for_object( + full_name='test_function', + py_object=test_function, + parser_config=parser_config) - # Make sure docstring shows up. - self.assertTrue(inspect.getdoc(test_function) in docs) + # Make sure the brief docstring is present + self.assertEqual( + tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief) # Make sure the extracted signature is good. - self.assertTrue( - 'test_function(unused_arg, unused_kwarg=\'default\')' in docs) + self.assertEqual(['unused_arg', "unused_kwarg='default'"], + page_info.signature) # Make sure this file is contained as the definition location. - self.assertTrue(os.path.relpath(__file__, '/') in docs) + self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) - def test_generate_markdown_for_function_with_kwargs(self): + def test_docs_for_function_with_kwargs(self): index = { 'test_function_with_args_kwargs': test_function_with_args_kwargs } + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + tree = { '': ['test_function_with_args_kwargs'] } + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') - docs = parser.generate_markdown(full_name='test_function_with_args_kwargs', - py_object=test_function_with_args_kwargs, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') + page_info = parser.docs_for_object( + full_name='test_function_with_args_kwargs', + py_object=test_function_with_args_kwargs, + parser_config=parser_config) - # Make sure docstring shows up. - self.assertTrue(inspect.getdoc(test_function_with_args_kwargs) in docs) + # Make sure the brief docstring is present + self.assertEqual( + tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0], + page_info.doc.brief) # Make sure the extracted signature is good. - self.assertTrue( - 'test_function_with_args_kwargs(unused_arg,' - ' *unused_args, **unused_kwargs)' in docs) + self.assertEqual(['unused_arg', '*unused_args', '**unused_kwargs'], + page_info.signature) - def test_references_replaced_in_generated_markdown(self): + def test_parse_md_docstring(self): + + def test_function_with_fancy_docstring(arg): + """Function with a fancy docstring. + + And a bunch of references: @{tf.reference}, another @{tf.reference}, + a member @{tf.reference.foo}, and a @{tf.third}. + + Args: + arg: An argument. + + Raises: + an exception + + Returns: + arg: the input, and + arg: the input, again. 
+ + @compatibility(numpy) + NumPy has nothing as awesome as this function. + @end_compatibility + + @compatibility(theano) + Theano has nothing as awesome as this function. + + Check it out. + @end_compatibility + + """ + return arg, arg + + class HasOneMember(object): + + def foo(self): + pass + + duplicate_of = {'tf.third': 'tf.fourth'} index = { - 'test_function_for_markdown_reference': - test_function_for_markdown_reference + 'tf.fancy': test_function_with_fancy_docstring, + 'tf.reference': HasOneMember, + 'tf.reference.foo': HasOneMember.foo, + 'tf.third': HasOneMember, + 'tf.fourth': HasOneMember } - tree = { - '': ['test_function_for_markdown_reference'] - } + visitor = DummyVisitor(index=index, duplicate_of=duplicate_of) - docs = parser.generate_markdown( - full_name='test_function_for_markdown_reference', - py_object=test_function_for_markdown_reference, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) - # Make sure docstring shows up and is properly processed. - expected_docs = parser.replace_references( - inspect.getdoc(test_function_for_markdown_reference), - relative_path_to_root='.', duplicate_of={}) + doc_info = parser._parse_md_docstring(test_function_with_fancy_docstring, + '../..', reference_resolver) - self.assertTrue(expected_docs in docs) + self.assertNotIn('@', doc_info.docstring) + self.assertNotIn('compatibility', doc_info.docstring) + self.assertNotIn('Raises:', doc_info.docstring) - def test_docstring_special_section(self): - index = { - 'test_function': test_function_with_fancy_docstring - } + self.assertEqual(len(doc_info.function_details), 3) + self.assertEqual(set(doc_info.compatibility.keys()), {'numpy', 'theano'}) - tree = { - '': 'test_function' - } - - docs = parser.generate_markdown( - full_name='test_function', - py_object=test_function_with_fancy_docstring, - duplicate_of={}, duplicates={}, - index=index, tree=tree, base_dir='/') - - expected = '\n'.join([ - 'Function with a fancy docstring.', - '', - '#### Args:', - '', - '* `arg`: An argument.', - '', - '', - '#### Returns:', - '', - '* `arg`: the input, and', - '* `arg`: the input, again.', - '']) - self.assertTrue(expected in docs) + self.assertEqual(doc_info.compatibility['numpy'], + 'NumPy has nothing as awesome as this function.\n') def test_generate_index(self): module = sys.modules[__name__] @@ -273,27 +368,30 @@ class ParserTest(googletest.TestCase): 'TestModule.TestClass.a_property': TestClass.a_property, 'TestModule.TestClass.ChildClass': TestClass.ChildClass, } - duplicate_of = { 'TestModule.test_function': 'test_function' } - docs = parser.generate_global_index('TestLibrary', 'test', - index=index, - duplicate_of=duplicate_of) + visitor = DummyVisitor(index=index, duplicate_of=duplicate_of) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + docs = parser.generate_global_index('TestLibrary', index=index, + reference_resolver=reference_resolver) # Make sure duplicates and non-top-level symbols are in the index, but # methods and properties are not. 
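For context, the `generate_global_index` hunk earlier in this diff shows that the index page is plain markdown: a `# All symbols in <library>` heading followed by one `* <link>` bullet per symbol, with methods and properties skipped. A minimal sketch of the updated call, which now takes a `ReferenceResolver` in place of the old `duplicate_of` mapping (all names come from the test setup above):

```python
# generate_global_index emits a heading plus '* <link>' lines; duplicates
# and non-top-level symbols are listed, methods and properties are not.
docs = parser.generate_global_index(
    'TestLibrary', index=index, reference_resolver=reference_resolver)
print(docs.splitlines()[0])  # '# All symbols in TestLibrary'
```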
- self.assertTrue('a_method' not in docs) - self.assertTrue('a_property' not in docs) - self.assertTrue('TestModule.TestClass' in docs) - self.assertTrue('TestModule.TestClass.ChildClass' in docs) - self.assertTrue('TestModule.test_function' in docs) + self.assertNotIn('a_method', docs) + self.assertNotIn('a_property', docs) + self.assertIn('TestModule.TestClass', docs) + self.assertIn('TestModule.TestClass.ChildClass', docs) + self.assertIn('TestModule.test_function', docs) # Leading backtick to make sure it's included top-level. # This depends on formatting, but should be stable. - self.assertTrue('`test_function' in docs) + self.assertIn('`test_function', docs) - def test_argspec_for_functoos_partial(self): + def test_argspec_for_functools_partial(self): # pylint: disable=unused-argument def test_function_for_partial1(arg1, arg2, kwarg1=1, kwarg2=2): @@ -305,45 +403,117 @@ class ParserTest(googletest.TestCase): # pylint: disable=protected-access # Make sure everything works for regular functions. - expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None, - (1, 2)) + expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, + None, (1, 2)) self.assertEqual(expected, parser._get_arg_spec(test_function_for_partial1)) # Make sure doing nothing works. - expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None, - (1, 2)) + expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, + None, (1, 2)) partial = functools.partial(test_function_for_partial1) self.assertEqual(expected, parser._get_arg_spec(partial)) # Make sure setting args from the front works. - expected = inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2)) + expected = tf_inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None, + (1, 2)) partial = functools.partial(test_function_for_partial1, 1) self.assertEqual(expected, parser._get_arg_spec(partial)) - expected = inspect.ArgSpec(['kwarg2',], None, None, (2,)) + expected = tf_inspect.ArgSpec(['kwarg2',], None, None, (2,)) partial = functools.partial(test_function_for_partial1, 1, 2, 3) self.assertEqual(expected, parser._get_arg_spec(partial)) # Make sure setting kwargs works. - expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,)) + expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,)) partial = functools.partial(test_function_for_partial1, kwarg1=0) self.assertEqual(expected, parser._get_arg_spec(partial)) - expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,)) + expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,)) partial = functools.partial(test_function_for_partial1, kwarg2=0) self.assertEqual(expected, parser._get_arg_spec(partial)) - expected = inspect.ArgSpec(['arg1'], None, None, ()) + expected = tf_inspect.ArgSpec(['arg1'], None, None, ()) partial = functools.partial(test_function_for_partial1, arg2=0, kwarg1=0, kwarg2=0) self.assertEqual(expected, parser._get_arg_spec(partial)) # Make sure *args, *kwargs is accounted for. 
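`parser._get_arg_spec` unwraps `functools.partial` objects, dropping bound positionals and already-supplied keywords from the reported spec. A sketch of the `*args`/`**kwargs` case asserted just below; the signature of `test_function_for_partial2` is inferred from the expected `ArgSpec` (its definition is elided in this hunk):

```python
import functools

def test_function_for_partial2(arg1, arg2, *my_args, **my_kwargs):  # inferred
  pass

# Binding both positionals leaves only the varargs/varkw names:
partial = functools.partial(test_function_for_partial2, 0, 1)
spec = parser._get_arg_spec(partial)
# spec == tf_inspect.ArgSpec(args=[], varargs='my_args',
#                            keywords='my_kwargs', defaults=())
```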
- expected = inspect.ArgSpec([], 'my_args', 'my_kwargs', ()) + expected = tf_inspect.ArgSpec([], 'my_args', 'my_kwargs', ()) partial = functools.partial(test_function_for_partial2, 0, 1) self.assertEqual(expected, parser._get_arg_spec(partial)) # pylint: enable=protected-access + def testSaveReferenceResolver(self): + you_cant_serialize_this = object() + + duplicate_of = {'AClass': ['AClass2']} + doc_index = {'doc': you_cant_serialize_this} + is_class = { + 'tf': False, + 'tf.AClass': True, + 'tf.AClass2': True, + 'tf.function': False + } + is_module = { + 'tf': True, + 'tf.AClass': False, + 'tf.AClass2': False, + 'tf.function': False + } + py_module_names = ['tf', 'tfdbg'] + + resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class, + is_module, py_module_names) + + outdir = googletest.GetTempDir() + + filepath = os.path.join(outdir, 'resolver.json') + + resolver.to_json_file(filepath) + resolver2 = parser.ReferenceResolver.from_json_file(filepath, doc_index) + + # There are no __slots__, so all fields are visible in __dict__. + self.assertEqual(resolver.__dict__, resolver2.__dict__) + +RELU_DOC = """Computes rectified linear: `max(features, 0)` + +Args: + features: A `Tensor`. Must be one of the following types: `float32`, + `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`, `uint16`, + `half`. + name: A name for the operation (optional) + +Returns: + A `Tensor`. Has the same type as `features` +""" + + +class TestParseFunctionDetails(googletest.TestCase): + + def testParseFunctionDetails(self): + docstring, function_details = parser._parse_function_details(RELU_DOC) + + self.assertEqual(len(function_details), 2) + args = function_details[0] + self.assertEqual(args.keyword, 'Args') + self.assertEmpty(args.header) + self.assertEqual(len(args.items), 2) + self.assertEqual(args.items[0][0], 'features') + self.assertEqual(args.items[1][0], 'name') + self.assertEqual(args.items[1][1], + ' A name for the operation (optional)\n\n') + returns = function_details[1] + self.assertEqual(returns.keyword, 'Returns') + + relu_doc_lines = RELU_DOC.split('\n') + self.assertEqual(docstring, relu_doc_lines[0] + '\n\n') + self.assertEqual(returns.header, relu_doc_lines[-2] + '\n') + + self.assertEqual( + RELU_DOC, + docstring + ''.join(str(detail) for detail in function_details)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py new file mode 100644 index 00000000000..365008c3f09 --- /dev/null +++ b/tensorflow/tools/docs/pretty_docs.py @@ -0,0 +1,344 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module for converting parsed doc content into markdown pages. + +The adjacent `parser` module creates `PageInfo` objects, containing all data +necessary to document an element of the TensorFlow API. 
+ +This module contains one public function, which handles the conversion of these +`PageInfo` objects into a markdown string: + + md_page = build_md_page(page_info) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + + +def build_md_page(page_info): + """Given a PageInfo object, return markdown for the page. + + Args: + page_info: must be a `parser.FunctionPageInfo`, `parser.ClassPageInfo`, or + `parser.ModulePageInfo` + + Returns: + Markdown for the page + + Raises: + ValueError: if `page_info` is an instance of an unrecognized class + """ + if page_info.for_function(): + return _build_function_page(page_info) + + if page_info.for_class(): + return _build_class_page(page_info) + + if page_info.for_module(): + return _build_module_page(page_info) + + raise ValueError('Unknown Page Info Type: %s' % type(page_info)) + + +def _build_function_page(page_info): + """Given a FunctionPageInfo object, return the page as an md string.""" + parts = [_Metadata(page_info.full_name).build_html()] + parts.append('# %s\n\n' % page_info.full_name) + + if len(page_info.aliases) > 1: + parts.append('### Aliases:\n\n') + parts.extend('* `%s`\n' % name for name in page_info.aliases) + parts.append('\n') + + if page_info.signature is not None: + parts.append(_build_signature(page_info)) + + if page_info.defined_in: + parts.append('\n\n') + parts.append(str(page_info.defined_in)) + + parts.append(page_info.guides) + parts.append(page_info.doc.docstring) + parts.append(_build_function_details(page_info.doc.function_details)) + parts.append(_build_compatibility(page_info.doc.compatibility)) + + return ''.join(parts) + + +def _build_class_page(page_info): + """Given a ClassPageInfo object, return the page as an md string.""" + meta_data = _Metadata(page_info.full_name) + for item in itertools.chain( + page_info.classes, + page_info.properties, + page_info.methods, + page_info.other_members): + meta_data.append(item) + + parts = [meta_data.build_html()] + + parts.append('# {page_info.full_name}\n\n'.format(page_info=page_info)) + + parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1]) + if page_info.bases: + parts.append('Inherits From: ') + + link_template = '[`{short_name}`]({url})' + parts.append(', '.join( + link_template.format(**base.__dict__) for base in page_info.bases)) + + parts.append('\n\n') + + if len(page_info.aliases) > 1: + parts.append('### Aliases:\n\n') + parts.extend('* Class `%s`\n' % name for name in page_info.aliases) + parts.append('\n') + + if page_info.defined_in is not None: + parts.append('\n\n') + parts.append(str(page_info.defined_in)) + + parts.append(page_info.guides) + parts.append(page_info.doc.docstring) + parts.append(_build_function_details(page_info.doc.function_details)) + assert not page_info.doc.compatibility + parts.append('\n\n') + + if page_info.classes: + parts.append('## Child Classes\n') + + link_template = ('[`class {class_info.short_name}`]' + '({class_info.url})\n\n') + class_links = sorted( + link_template.format(class_info=class_info) + for class_info in page_info.classes) + + parts.extend(class_links) + + if page_info.properties: + parts.append('## Properties\n\n') + for prop_info in sorted(page_info.properties): + h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n' + parts.append(h3.format(short_name=prop_info.short_name)) + + parts.append(prop_info.doc.docstring) + parts.append(_build_function_details(prop_info.doc.function_details)) + assert not prop_info.doc.compatibility + parts.append('\n\n') + + parts.append('\n\n') + + if page_info.methods: + parts.append('## Methods\n\n') + # Sort the methods list, but make sure constructors come first. + constructors = ['__init__', '__new__'] + inits = [method for method in page_info.methods + if method.short_name in constructors] + others = [method for method in page_info.methods + if method.short_name not in constructors] + + for method_info in sorted(inits) + sorted(others): + h3 = ('<h3 id="{short_name}">' + '<code>{short_name}</code>' + '</h3>\n\n') + parts.append(h3.format(**method_info.__dict__)) + + if method_info.signature is not None: + parts.append(_build_signature(method_info)) + + parts.append(method_info.doc.docstring) + parts.append(_build_function_details(method_info.doc.function_details)) + parts.append(_build_compatibility(method_info.doc.compatibility)) + parts.append('\n\n') + parts.append('\n\n') + + if page_info.other_members: + parts.append('## Class Members\n\n') + + # TODO(markdaoust): Document the value of the members, + # at least for basic types. + + h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n' + others_member_headings = (h3.format(short_name=info.short_name) + for info in sorted(page_info.other_members)) + parts.extend(others_member_headings) + + return ''.join(parts) + + +def _build_module_page(page_info): + """Given a ModulePageInfo object, return the page as an md string.""" + meta_data = _Metadata(page_info.full_name) + + # Objects with their own pages are not added to the metadata list for the + # module, as the only thing on the module page is a link to the object's page. + for item in page_info.other_members: + meta_data.append(item) + + parts = [meta_data.build_html()] + + parts.append( + '# Module: {full_name}\n\n'.format(full_name=page_info.full_name)) + + if len(page_info.aliases) > 1: + parts.append('### Aliases:\n\n') + parts.extend('* Module `%s`\n' % name for name in page_info.aliases) + parts.append('\n') + + if page_info.defined_in is not None: + parts.append('\n\n') + parts.append(str(page_info.defined_in)) + + parts.append(page_info.doc.docstring) + parts.append('\n\n') + + if page_info.modules: + parts.append('## Modules\n\n') + template = '[`{short_name}`]({url}) module' + + for item in page_info.modules: + parts.append(template.format(**item.__dict__)) + + if item.doc.brief: + parts.append(': ' + item.doc.brief) + + parts.append('\n\n') + + if page_info.classes: + parts.append('## Classes\n\n') + template = '[`class {short_name}`]({url})' + + for item in page_info.classes: + parts.append(template.format(**item.__dict__)) + + if item.doc.brief: + parts.append(': ' + item.doc.brief) + + parts.append('\n\n') + + if page_info.functions: + parts.append('## Functions\n\n') + template = '[`{short_name}(...)`]({url})' + + for item in page_info.functions: + parts.append(template.format(**item.__dict__)) + + if item.doc.brief: + parts.append(': ' + item.doc.brief) + + parts.append('\n\n') + + if page_info.other_members: + # TODO(markdaoust): Document the value of the members, + # at least for basic types.
+ + parts.append('## Other Members\n\n') + + for item in page_info.other_members: + parts.append('`{short_name}`\n\n'.format(**item.__dict__)) + + return ''.join(parts) + + +def _build_signature(obj_info): + """Returns an md code block showing the function signature.""" + # Special case tf.range, since it has an optional first argument + if obj_info.full_name == 'tf.range': + return ( + '``` python\n' + "range(limit, delta=1, dtype=None, name='range')\n" + "range(start, limit, delta=1, dtype=None, name='range')\n" + '```\n\n') + + signature_template = '\n'.join([ + '``` python', + '{name}({sig})', + '```\n\n']) + + if not obj_info.signature: + sig = '' + elif len(obj_info.signature) == 1: + sig = obj_info.signature[0] + else: + sig = ',\n'.join(' %s' % sig_item for sig_item in obj_info.signature) + sig = '\n'+sig+'\n' + + return signature_template.format(name=obj_info.short_name, sig=sig) + + +def _build_compatibility(compatibility): + """Return the compatibility section as an md string.""" + parts = [] + sorted_keys = sorted(compatibility.keys()) + for key in sorted_keys: + + value = compatibility[key] + parts.append('\n\n#### %s compatibility\n%s\n' % (key, value)) + + return ''.join(parts) + + +def _build_function_details(function_details): + """Return the function details section as an md string.""" + parts = [] + for detail in function_details: + sub = [] + sub.append('#### ' + detail.keyword + ':\n\n') + sub.append(detail.header) + for key, value in detail.items: + sub.append('* `%s`:%s' % (key, value)) + parts.append(''.join(sub)) + + return '\n'.join(parts) + + +class _Metadata(object): + """A class for building a page's Metadata block. + + Attributes: + name: The name of the page being described by the Metadata block. + """ + + def __init__(self, name): + """Create a Metadata builder. + + Args: + name: The name of the page being described by the Metadata block. + """ + self.name = name + self._content = [] + + def append(self, item): + """Add an item from the page to the Metadata block. + + Args: + item: The parsed page section to add. + """ + self._content.append(item.short_name) + + def build_html(self): + """Return the Metadata block as an HTML string.""" + schema = 'http://developers.google.com/ReferenceObject' + parts = ['
<div itemscope itemtype="%s">' % schema] + + parts.append('<meta itemprop="name" content="%s" />' % self.name) + for item in self._content: + parts.append('<meta itemprop="property" content="%s" />' % item) + + parts.extend(['</div>
', '', '']) + + return '\n'.join(parts) diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py new file mode 100644 index 00000000000..216353ecee3 --- /dev/null +++ b/tensorflow/tools/docs/py_guide_parser.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Library for operating on Python API Guide files.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + + +def md_files_in_dir(py_guide_src_dir): + """Returns a list of filename (full_path, base) pairs for guide files.""" + all_in_dir = [(os.path.join(py_guide_src_dir, f), f) + for f in os.listdir(py_guide_src_dir)] + return [(full, f) for full, f in all_in_dir + if os.path.isfile(full) and f.endswith('.md')] + + +class PyGuideParser(object): + """Simple parsing of a guide .md file. + + Descendants can override the process_*() functions (called by process()) + to either record information from the guide, or call replace_line() + to affect the return value of process(). + """ + + def __init__(self): + self._lines = None + + def process(self, full_path): + """Read and process the file at `full_path`.""" + md_string = open(full_path).read() + self._lines = md_string.split('\n') + seen = set() + + in_blockquote = False + for i, line in enumerate(self._lines): + if '```' in line: + in_blockquote = not in_blockquote + + if not in_blockquote and line.startswith('# '): + self.process_title(i, line[2:]) + elif not in_blockquote and line.startswith('## '): + section_title = line.strip()[3:] + existing_tag = re.search(' {([^}]+)} *$', line) + if existing_tag: + tag = existing_tag.group(1) + else: + tag = re.sub('[^a-zA-Z0-9]+', '_', section_title) + if tag in seen: + suffix = 0 + while True: + candidate = '%s_%d' % (tag, suffix) + if candidate not in seen: + tag = candidate + break + seen.add(tag) + self.process_section(i, section_title, tag) + + elif in_blockquote: + self.process_in_blockquote(i, line) + else: + self.process_line(i, line) + + ret = '\n'.join(self._lines) + self._lines = None + return ret + + def replace_line(self, line_number, line): + """Replace the contents of line numbered `line_number` with `line`.""" + self._lines[line_number] = line + + def process_title(self, line_number, title): + pass + + def process_section(self, line_number, section_title, tag): + pass + + def process_in_blockquote(self, line_number, line): + pass + + def process_line(self, line_number, line): + pass diff --git a/tensorflow/tools/docs/py_guide_parser_test.py b/tensorflow/tools/docs/py_guide_parser_test.py new file mode 100644 index 00000000000..168b0535a94 --- /dev/null +++ b/tensorflow/tools/docs/py_guide_parser_test.py @@ -0,0 +1,84 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for py_guide_parser.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.platform import test +from tensorflow.tools.docs import py_guide_parser + + +class TestPyGuideParser(py_guide_parser.PyGuideParser): + + def __init__(self): + self.calls = [] + py_guide_parser.PyGuideParser.__init__(self) + + def process_title(self, line_number, title): + self.calls.append((line_number, 't', title)) + + def process_section(self, line_number, section_title, tag): + self.calls.append((line_number, 's', '%s : %s' % (section_title, tag))) + + def process_in_blockquote(self, line_number, line): + self.calls.append((line_number, 'b', line)) + self.replace_line(line_number, line + ' BQ') + + def process_line(self, line_number, line): + self.calls.append((line_number, 'l', line)) + + +class PyGuideParserTest(test.TestCase): + + def testBasics(self): + tmp = os.path.join(test.get_temp_dir(), 'py_guide_parser_test.md') + f = open(tmp, 'w') + f.write("""# a title +a line +## a section +```shell +in a blockquote +``` +out of blockquote +""") + f.close() + parser = TestPyGuideParser() + result = parser.process(tmp) + expected = """# a title +a line +## a section +```shell BQ +in a blockquote BQ +``` +out of blockquote +""" + self.assertEqual(expected, result) + expected = [(0, 't', 'a title'), + (1, 'l', 'a line'), + (2, 's', 'a section : a_section'), + (3, 'b', '```shell'), + (4, 'b', 'in a blockquote'), + (5, 'l', '```'), + (6, 'l', 'out of blockquote'), + (7, 'l', '')] + self.assertEqual(expected, parser.calls) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/tools/docs/tf-doxy_for_md-config b/tensorflow/tools/docs/tf-doxy_for_md-config deleted file mode 100644 index b7fd6e95076..00000000000 --- a/tensorflow/tools/docs/tf-doxy_for_md-config +++ /dev/null @@ -1,2280 +0,0 @@ -# Doxyfile 1.8.5 - -# This file describes the settings to be used by the documentation system -# doxygen (www.doxygen.org) for a project. -# -# All text after a double hash (##) is considered a comment and is placed in -# front of the TAG it is preceding. -# -# All text after a single hash (#) is considered a comment and will be ignored. -# The format is: -# TAG = value [value, ...] -# For lists, items can also be appended using: -# TAG += value [value, ...] -# Values that contain spaces should be placed between quotes (\" \"). - -#--------------------------------------------------------------------------- -# Project related configuration options -#--------------------------------------------------------------------------- - -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. 
Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. -# The default value is: UTF-8. - -DOXYFILE_ENCODING = UTF-8 - -# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by -# double-quotes, unless you are using Doxywizard) that should identify the -# project for which the documentation is generated. This name is used in the -# title of most generated pages and in a few other places. -# The default value is: My Project. - -PROJECT_NAME = "TensorFlow" - -# The PROJECT_NUMBER tag can be used to enter a project or revision number. This -# could be handy for archiving the generated documentation or if some version -# control system is used. - -PROJECT_NUMBER = 0.0.0 - -# Using the PROJECT_BRIEF tag one can provide an optional one line description -# for a project that appears at the top of each page and should give viewer a -# quick idea about the purpose of the project. Keep the description short. - -PROJECT_BRIEF = - -# With the PROJECT_LOGO tag one can specify an logo or icon that is included in -# the documentation. The maximum height of the logo should not exceed 55 pixels -# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo -# to the output directory. - -PROJECT_LOGO = - -# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path -# into which the generated documentation will be written. If a relative path is -# entered, it will be relative to the location where doxygen was started. If -# left blank the current directory will be used. - -OUTPUT_DIRECTORY = /tmp/tensorflow-docs/ - -# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this -# option can be useful when feeding doxygen a huge amount of source files, where -# putting all generated files in the same directory would otherwise causes -# performance problems for the file system. -# The default value is: NO. - -CREATE_SUBDIRS = NO - -# The OUTPUT_LANGUAGE tag is used to specify the language in which all -# documentation generated by doxygen is written. Doxygen will use this -# information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Brazilian, Catalan, Chinese, Chinese- -# Traditional, Croatian, Czech, Danish, Dutch, English, Esperanto, Farsi, -# Finnish, French, German, Greek, Hungarian, Italian, Japanese, Japanese-en, -# Korean, Korean-en, Latvian, Norwegian, Macedonian, Persian, Polish, -# Portuguese, Romanian, Russian, Serbian, Slovak, Slovene, Spanish, Swedish, -# Turkish, Ukrainian and Vietnamese. -# The default value is: English. - -OUTPUT_LANGUAGE = English - -# If the BRIEF_MEMBER_DESC tag is set to YES doxygen will include brief member -# descriptions after the members that are listed in the file and class -# documentation (similar to Javadoc). Set to NO to disable this. -# The default value is: YES. - -BRIEF_MEMBER_DESC = YES - -# If the REPEAT_BRIEF tag is set to YES doxygen will prepend the brief -# description of a member or function before the detailed description -# -# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the -# brief descriptions will be completely suppressed. -# The default value is: YES. 
- -REPEAT_BRIEF = YES - -# This tag implements a quasi-intelligent brief description abbreviator that is -# used to form the text in various listings. Each string in this list, if found -# as the leading text of the brief description, will be stripped from the text -# and the result, after processing the whole list, is used as the annotated -# text. Otherwise, the brief description is used as-is. If left blank, the -# following values are used ($name is automatically replaced with the name of -# the entity):The $name class, The $name widget, The $name file, is, provides, -# specifies, contains, represents, a, an and the. - -ABBREVIATE_BRIEF = - -# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then -# doxygen will generate a detailed section even if there is only a brief -# description. -# The default value is: NO. - -ALWAYS_DETAILED_SEC = NO - -# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all -# inherited members of a class in the documentation of that class as if those -# members were ordinary class members. Constructors, destructors and assignment -# operators of the base classes will not be shown. -# The default value is: NO. - -INLINE_INHERITED_MEMB = NO - -# If the FULL_PATH_NAMES tag is set to YES doxygen will prepend the full path -# before files name in the file list and in the header files. If set to NO the -# shortest path that makes the file name unique will be used -# The default value is: YES. - -FULL_PATH_NAMES = YES - -# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. -# Stripping is only done if one of the specified strings matches the left-hand -# part of the path. The tag can be used to show relative paths in the file list. -# If left blank the directory from which doxygen is run is used as the path to -# strip. -# -# Note that you can specify absolute paths here, but also relative paths, which -# will be relative from the directory where doxygen is started. -# This tag requires that the tag FULL_PATH_NAMES is set to YES. - -STRIP_FROM_PATH = - -# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the -# path mentioned in the documentation of a class, which tells the reader which -# header file to include in order to use a class. If left blank only the name of -# the header file containing the class definition is used. Otherwise one should -# specify the list of include paths that are normally passed to the compiler -# using the -I flag. - -STRIP_FROM_INC_PATH = - -# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but -# less readable) file names. This can be useful is your file systems doesn't -# support long names like on DOS, Mac, or CD-ROM. -# The default value is: NO. - -SHORT_NAMES = NO - -# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the -# first line (until the first dot) of a Javadoc-style comment as the brief -# description. If set to NO, the Javadoc-style will behave just like regular Qt- -# style comments (thus requiring an explicit @brief command for a brief -# description.) -# The default value is: NO. - -JAVADOC_AUTOBRIEF = NO - -# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first -# line (until the first dot) of a Qt-style comment as the brief description. If -# set to NO, the Qt-style will behave just like regular Qt-style comments (thus -# requiring an explicit \brief command for a brief description.) -# The default value is: NO. 
- -QT_AUTOBRIEF = NO - -# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a -# multi-line C++ special comment block (i.e. a block of //! or /// comments) as -# a brief description. This used to be the default behavior. The new default is -# to treat a multi-line C++ comment block as a detailed description. Set this -# tag to YES if you prefer the old behavior instead. -# -# Note that setting this tag to YES also means that rational rose comments are -# not recognized any more. -# The default value is: NO. - -MULTILINE_CPP_IS_BRIEF = NO - -# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the -# documentation from any documented member that it re-implements. -# The default value is: YES. - -INHERIT_DOCS = YES - -# If the SEPARATE_MEMBER_PAGES tag is set to YES, then doxygen will produce a -# new page for each member. If set to NO, the documentation of a member will be -# part of the file/class/namespace that contains it. -# The default value is: NO. - -SEPARATE_MEMBER_PAGES = NO - -# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen -# uses this value to replace tabs by spaces in code fragments. -# Minimum value: 1, maximum value: 16, default value: 4. - -TAB_SIZE = 4 - -# This tag can be used to specify a number of aliases that act as commands in -# the documentation. An alias has the form: -# name=value -# For example adding -# "sideeffect=@par Side Effects:\n" -# will allow you to put the command \sideeffect (or @sideeffect) in the -# documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. - -ALIASES = - -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - -# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources -# only. Doxygen will then generate output that is more tailored for C. For -# instance, some of the names that are used will be different. The list of all -# members will be omitted, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_FOR_C = NO - -# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or -# Python sources only. Doxygen will then generate output that is more tailored -# for that language. For instance, namespaces will be presented as packages, -# qualified scopes will look different, etc. -# The default value is: NO. - -OPTIMIZE_OUTPUT_JAVA = NO - -# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran -# sources. Doxygen will then generate output that is tailored for Fortran. -# The default value is: NO. - -OPTIMIZE_FOR_FORTRAN = NO - -# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL -# sources. Doxygen will then generate output that is tailored for VHDL. -# The default value is: NO. - -OPTIMIZE_OUTPUT_VHDL = NO - -# Doxygen selects the parser to use depending on the extension of the files it -# parses. With this tag you can assign which parser to use for a given -# extension. Doxygen has a built-in mapping, but you can override or extend it -# using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran, VHDL. 
For instance to make -# doxygen treat .inc files as Fortran files (default is PHP), and .f files as C -# (default is Fortran), use: inc=Fortran f=C. -# -# Note For files without extension you can use no_extension as a placeholder. -# -# Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. - -EXTENSION_MAPPING = - -# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments -# according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. -# The output of markdown processing is further processed by doxygen, so you can -# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in -# case of backward compatibilities issues. -# The default value is: YES. - -MARKDOWN_SUPPORT = YES - -# When enabled doxygen tries to link words that correspond to documented -# classes, or namespaces to their corresponding documentation. Such a link can -# be prevented in individual cases by putting a % sign in front of the word -# or globally by setting AUTOLINK_SUPPORT to NO. -# The default value is: YES. - -AUTOLINK_SUPPORT = YES - -# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want -# to include (a tag file for) the STL sources as input, then you should set this -# tag to YES in order to let doxygen match functions declarations and -# definitions whose arguments contain STL classes (e.g. func(std::string); -# versus func(std::string) {}). This also make the inheritance and collaboration -# diagrams that involve STL classes more complete and accurate. -# The default value is: NO. - -BUILTIN_STL_SUPPORT = NO - -# If you use Microsoft's C++/CLI language, you should set this option to YES to -# enable parsing support. -# The default value is: NO. - -CPP_CLI_SUPPORT = NO - -# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen -# will parse them like normal C++ but will assume all classes use public instead -# of private inheritance when no explicit protection keyword is present. -# The default value is: NO. - -SIP_SUPPORT = NO - -# For Microsoft's IDL there are propget and propput attributes to indicate -# getter and setter methods for a property. Setting this option to YES will make -# doxygen to replace the get and set methods by a property in the documentation. -# This will only work if the methods are indeed getting or setting a simple -# type. If this is not the case, or you want to show the methods anyway, you -# should set this option to NO. -# The default value is: YES. - -IDL_PROPERTY_SUPPORT = YES - -# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC -# tag is set to YES, then doxygen will reuse the documentation of the first -# member in the group (if any) for the other members of the group. By default -# all members of a group must be documented explicitly. -# The default value is: NO. - -DISTRIBUTE_GROUP_DOC = NO - -# Set the SUBGROUPING tag to YES to allow class member groups of the same type -# (for instance a group of public functions) to be put as a subgroup of that -# type (e.g. under the Public Functions section). Set it to NO to prevent -# subgrouping. Alternatively, this can be done per class using the -# \nosubgrouping command. -# The default value is: YES. 
- -SUBGROUPING = YES - -# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions -# are shown inside the group in which they are included (e.g. using \ingroup) -# instead of on a separate page (for HTML and Man pages) or section (for LaTeX -# and RTF). -# -# Note that this feature does not work in combination with -# SEPARATE_MEMBER_PAGES. -# The default value is: NO. - -INLINE_GROUPED_CLASSES = NO - -# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions -# with only public data fields or simple typedef fields will be shown inline in -# the documentation of the scope in which they are defined (i.e. file, -# namespace, or group documentation), provided this scope is documented. If set -# to NO, structs, classes, and unions are shown on a separate page (for HTML and -# Man pages) or section (for LaTeX and RTF). -# The default value is: NO. - -INLINE_SIMPLE_STRUCTS = NO - -# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or -# enum is documented as struct, union, or enum with the name of the typedef. So -# typedef struct TypeS {} TypeT, will appear in the documentation as a struct -# with name TypeT. When disabled the typedef will appear as a member of a file, -# namespace, or class. And the struct will be named TypeS. This can typically be -# useful for C code in case the coding convention dictates that all compound -# types are typedef'ed and only the typedef is referenced, never the tag name. -# The default value is: NO. - -TYPEDEF_HIDES_STRUCT = NO - -# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This -# cache is used to resolve symbols given their name and scope. Since this can be -# an expensive process and often the same symbol appears multiple times in the -# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small -# doxygen will become slower. If the cache is too large, memory is wasted. The -# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range -# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 -# symbols. At the end of a run doxygen will report the cache usage and suggest -# the optimal cache size from a speed point of view. -# Minimum value: 0, maximum value: 9, default value: 0. - -LOOKUP_CACHE_SIZE = 0 - -#--------------------------------------------------------------------------- -# Build related configuration options -#--------------------------------------------------------------------------- - -# If the EXTRACT_ALL tag is set to YES doxygen will assume all entities in -# documentation are documented, even if no documentation was available. Private -# class members and static file members will be hidden unless the -# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. -# Note: This will also disable the warnings about undocumented members that are -# normally produced when WARNINGS is set to YES. -# The default value is: NO. - -EXTRACT_ALL = NO - -# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will -# be included in the documentation. -# The default value is: NO. - -EXTRACT_PRIVATE = NO - -# If the EXTRACT_PACKAGE tag is set to YES all members with package or internal -# scope will be included in the documentation. -# The default value is: NO. - -EXTRACT_PACKAGE = NO - -# If the EXTRACT_STATIC tag is set to YES all static members of a file will be -# included in the documentation. -# The default value is: NO. 
- -EXTRACT_STATIC = YES - -# If the EXTRACT_LOCAL_CLASSES tag is set to YES classes (and structs) defined -# locally in source files will be included in the documentation. If set to NO -# only classes defined in header files are included. Does not have any effect -# for Java sources. -# The default value is: YES. - -EXTRACT_LOCAL_CLASSES = YES - -# This flag is only useful for Objective-C code. When set to YES local methods, -# which are defined in the implementation section but not in the interface are -# included in the documentation. If set to NO only methods in the interface are -# included. -# The default value is: NO. - -EXTRACT_LOCAL_METHODS = NO - -# If this flag is set to YES, the members of anonymous namespaces will be -# extracted and appear in the documentation as a namespace called -# 'anonymous_namespace{file}', where file will be replaced with the base name of -# the file that contains the anonymous namespace. By default anonymous namespace -# are hidden. -# The default value is: NO. - -EXTRACT_ANON_NSPACES = NO - -# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all -# undocumented members inside documented classes or files. If set to NO these -# members will be included in the various overviews, but no documentation -# section is generated. This option has no effect if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_MEMBERS = NO - -# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all -# undocumented classes that are normally visible in the class hierarchy. If set -# to NO these classes will be included in the various overviews. This option has -# no effect if EXTRACT_ALL is enabled. -# The default value is: NO. - -HIDE_UNDOC_CLASSES = NO - -# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO these declarations will be -# included in the documentation. -# The default value is: NO. - -HIDE_FRIEND_COMPOUNDS = NO - -# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any -# documentation blocks found inside the body of a function. If set to NO these -# blocks will be appended to the function's detailed documentation block. -# The default value is: NO. - -HIDE_IN_BODY_DOCS = NO - -# The INTERNAL_DOCS tag determines if documentation that is typed after a -# \internal command is included. If the tag is set to NO then the documentation -# will be excluded. Set it to YES to include the internal documentation. -# The default value is: NO. - -INTERNAL_DOCS = NO - -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. - -CASE_SENSE_NAMES = YES - -# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with -# their full class and namespace scopes in the documentation. If set to YES the -# scope will be hidden. -# The default value is: NO. - -HIDE_SCOPE_NAMES = NO - -# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of -# the files that are included by a file in the documentation of that file. -# The default value is: YES. 
- -SHOW_INCLUDE_FILES = YES - -# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include -# files with double quotes in the documentation rather than with sharp brackets. -# The default value is: NO. - -FORCE_LOCAL_INCLUDES = NO - -# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the -# documentation for inline members. -# The default value is: YES. - -INLINE_INFO = YES - -# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the -# (detailed) documentation of file and class members alphabetically by member -# name. If set to NO the members will appear in declaration order. -# The default value is: YES. - -SORT_MEMBER_DOCS = YES - -# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief -# descriptions of file, namespace and class members alphabetically by member -# name. If set to NO the members will appear in declaration order. -# The default value is: NO. - -SORT_BRIEF_DOCS = NO - -# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the -# (brief and detailed) documentation of class members so that constructors and -# destructors are listed first. If set to NO the constructors will appear in the -# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. -# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief -# member documentation. -# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting -# detailed member documentation. -# The default value is: NO. - -SORT_MEMBERS_CTORS_1ST = NO - -# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy -# of group names into alphabetical order. If set to NO the group names will -# appear in their defined order. -# The default value is: NO. - -SORT_GROUP_NAMES = NO - -# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by -# fully-qualified names, including namespaces. If set to NO, the class list will -# be sorted only by class name, not including the namespace part. -# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. -# Note: This option applies only to the class list, not to the alphabetical -# list. -# The default value is: NO. - -SORT_BY_SCOPE_NAME = NO - -# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper -# type resolution of all parameters of a function it will reject a match between -# the prototype and the implementation of a member function even if there is -# only one candidate or it is obvious which candidate to choose by doing a -# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still -# accept a match between prototype and implementation in such cases. -# The default value is: NO. - -STRICT_PROTO_MATCHING = NO - -# The GENERATE_TODOLIST tag can be used to enable ( YES) or disable ( NO) the -# todo list. This list is created by putting \todo commands in the -# documentation. -# The default value is: YES. - -GENERATE_TODOLIST = YES - -# The GENERATE_TESTLIST tag can be used to enable ( YES) or disable ( NO) the -# test list. This list is created by putting \test commands in the -# documentation. -# The default value is: YES. - -GENERATE_TESTLIST = YES - -# The GENERATE_BUGLIST tag can be used to enable ( YES) or disable ( NO) the bug -# list. This list is created by putting \bug commands in the documentation. -# The default value is: YES. 
- -GENERATE_BUGLIST = YES - -# The GENERATE_DEPRECATEDLIST tag can be used to enable ( YES) or disable ( NO) -# the deprecated list. This list is created by putting \deprecated commands in -# the documentation. -# The default value is: YES. - -GENERATE_DEPRECATEDLIST= YES - -# The ENABLED_SECTIONS tag can be used to enable conditional documentation -# sections, marked by \if ... \endif and \cond -# ... \endcond blocks. - -ENABLED_SECTIONS = - -# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the -# initial value of a variable or macro / define can have for it to appear in the -# documentation. If the initializer consists of more lines than specified here -# it will be hidden. Use a value of 0 to hide initializers completely. The -# appearance of the value of individual variables and macros / defines can be -# controlled using \showinitializer or \hideinitializer command in the -# documentation regardless of this setting. -# Minimum value: 0, maximum value: 10000, default value: 30. - -MAX_INITIALIZER_LINES = 30 - -# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at -# the bottom of the documentation of classes and structs. If set to YES the list -# will mention the files that were used to generate the documentation. -# The default value is: YES. - -SHOW_USED_FILES = YES - -# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This -# will remove the Files entry from the Quick Index and from the Folder Tree View -# (if specified). -# The default value is: YES. - -SHOW_FILES = YES - -# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces -# page. This will remove the Namespaces entry from the Quick Index and from the -# Folder Tree View (if specified). -# The default value is: YES. - -SHOW_NAMESPACES = YES - -# The FILE_VERSION_FILTER tag can be used to specify a program or script that -# doxygen should invoke to get the current version for each file (typically from -# the version control system). Doxygen will invoke the program by executing (via -# popen()) the command input-file, where command is the value of the -# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided -# by doxygen. Whatever the program writes to standard output is used as the file -# version. For an example see the documentation. - -FILE_VERSION_FILTER = - -# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed -# by doxygen. The layout file controls the global structure of the generated -# output files in an output format independent way. To create the layout file -# that represents doxygen's defaults, run doxygen with the -l option. You can -# optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. -# -# Note that if you run doxygen from a directory containing a file called -# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE -# tag is left empty. - -LAYOUT_FILE = - -# The CITE_BIB_FILES tag can be used to specify one or more bib files containing -# the reference definitions. This must be a list of .bib files. The .bib -# extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. -# For LaTeX the style of the bibliography can be controlled using -# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the -# search path. 
-# also \cite for info how to create references.
-
-CITE_BIB_FILES         =
-
-#---------------------------------------------------------------------------
-# Configuration options related to warning and progress messages
-#---------------------------------------------------------------------------
-
-# The QUIET tag can be used to turn on/off the messages that are generated to
-# standard output by doxygen. If QUIET is set to YES this implies that the
-# messages are off.
-# The default value is: NO.
-
-QUIET                  = NO
-
-# The WARNINGS tag can be used to turn on/off the warning messages that are
-# generated to standard error ( stderr) by doxygen. If WARNINGS is set to YES
-# this implies that the warnings are on.
-#
-# Tip: Turn warnings on while writing the documentation.
-# The default value is: YES.
-
-WARNINGS               = YES
-
-# If the WARN_IF_UNDOCUMENTED tag is set to YES, then doxygen will generate
-# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
-# will automatically be disabled.
-# The default value is: YES.
-
-WARN_IF_UNDOCUMENTED   = YES
-
-# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for
-# potential errors in the documentation, such as not documenting some parameters
-# in a documented function, or documenting parameters that don't exist or using
-# markup commands wrongly.
-# The default value is: YES.
-
-WARN_IF_DOC_ERROR      = YES
-
-# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
-# are documented, but have no documentation for their parameters or return
-# value. If set to NO doxygen will only warn about wrong or incomplete parameter
-# documentation, but not about the absence of documentation.
-# The default value is: NO.
-
-WARN_NO_PARAMDOC       = NO
-
-# The WARN_FORMAT tag determines the format of the warning messages that doxygen
-# can produce. The string should contain the $file, $line, and $text tags, which
-# will be replaced by the file and line number from which the warning originated
-# and the warning text. Optionally the format may contain $version, which will
-# be replaced by the version of the file (if it could be obtained via
-# FILE_VERSION_FILTER)
-# The default value is: $file:$line: $text.
-
-WARN_FORMAT            = "$file:$line: $text"
-
-# The WARN_LOGFILE tag can be used to specify a file to which warning and error
-# messages should be written. If left blank the output is written to standard
-# error (stderr).
-
-WARN_LOGFILE           =
-
-#---------------------------------------------------------------------------
-# Configuration options related to the input files
-#---------------------------------------------------------------------------
-
-# The INPUT tag is used to specify the files and/or directories that contain
-# documented source files. You may enter file names like myfile.cpp or
-# directories like /usr/src/myproject. Separate the files or directories with
-# spaces.
-# Note: If this tag is empty the current directory is searched.
-
-INPUT                  = core/framework core/lib/core core/platform core/public
-
-# This tag can be used to specify the character encoding of the source files
-# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
-# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
-# documentation (see: http://www.gnu.org/software/libiconv) for the list of
-# possible encodings.
-# The default value is: UTF-8.
-
-INPUT_ENCODING         = UTF-8
-
-# If the value of the INPUT tag contains directories, you can use the
-# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
-# *.h) to filter out the source-files in the directories. If left blank the
-# following patterns are tested: *.c, *.cc, *.cxx, *.cpp, *.c++, *.java, *.ii,
-# *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp,
-# *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown,
-# *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf,
-# *.qsf, *.as and *.js.
-
-FILE_PATTERNS          =
-
-# The RECURSIVE tag can be used to specify whether or not subdirectories should
-# be searched for input files as well.
-# The default value is: NO.
-
-RECURSIVE              = NO
-
-# The EXCLUDE tag can be used to specify files and/or directories that should be
-# excluded from the INPUT source files. This way you can easily exclude a
-# subdirectory from a directory tree whose root is specified with the INPUT tag.
-#
-# Note that relative paths are relative to the directory from which doxygen is
-# run.
-
-EXCLUDE                =
-
-# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
-# directories that are symbolic links (a Unix file system feature) are excluded
-# from the input.
-# The default value is: NO.
-
-EXCLUDE_SYMLINKS       = NO
-
-# If the value of the INPUT tag contains directories, you can use the
-# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
-# certain files from those directories.
-#
-# Note that the wildcards are matched against the file with absolute path, so to
-# exclude all test directories for example use the pattern */test/*
-
-EXCLUDE_PATTERNS       =
-
-# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
-# (namespaces, classes, functions, etc.) that should be excluded from the
-# output. The symbol name can be a fully qualified name, a word, or if the
-# wildcard * is used, a substring. Examples: ANamespace, AClass,
-# AClass::ANamespace, ANamespace::*Test
-#
-# Note that the wildcards are matched against the file with absolute path, so to
-# exclude all test directories use the pattern */test/*
-
-EXCLUDE_SYMBOLS        =
-
-# The EXAMPLE_PATH tag can be used to specify one or more files or directories
-# that contain example code fragments that are included (see the \include
-# command).
-
-EXAMPLE_PATH           =
-
-# If the value of the EXAMPLE_PATH tag contains directories, you can use the
-# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and
-# *.h) to filter out the source-files in the directories. If left blank all
-# files are included.
-
-EXAMPLE_PATTERNS       =
-
-# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
-# searched for input files to be used with the \include or \dontinclude commands
-# irrespective of the value of the RECURSIVE tag.
-# The default value is: NO.
-
-EXAMPLE_RECURSIVE      = NO
-
-# The IMAGE_PATH tag can be used to specify one or more files or directories
-# that contain images that are to be included in the documentation (see the
-# \image command).
-
-IMAGE_PATH             =
-
-# The INPUT_FILTER tag can be used to specify a program that doxygen should
-# invoke to filter for each input file. Doxygen will invoke the filter program
-# by executing (via popen()) the command:
-#
-# <filter> <input-file>
-#
-# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
-# name of an input file. Doxygen will then use the output that the filter
-# program writes to standard output. If FILTER_PATTERNS is specified, this tag
-# will be ignored.
-#
-# Note that the filter must not add or remove lines; it is applied before the
-# code is scanned, but not when the output code is generated. If lines are added
-# or removed, the anchors will not be placed correctly.
-
-INPUT_FILTER           =
-
-# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
-# basis. Doxygen will compare the file name with each pattern and apply the
-# filter if there is a match. The filters are a list of the form: pattern=filter
-# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
-# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
-# patterns match the file name, INPUT_FILTER is applied.
-
-FILTER_PATTERNS        =
-
-# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
-# INPUT_FILTER ) will also be used to filter the input files that are used for
-# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
-# The default value is: NO.
-
-FILTER_SOURCE_FILES    = NO
-
-# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
-# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
-# it is also possible to disable source filtering for a specific pattern using
-# *.ext= (so without naming a filter).
-# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
-
-FILTER_SOURCE_PATTERNS =
-
-# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
-# is part of the input, its contents will be placed on the main page
-# (index.html). This can be useful if you have a project on for instance GitHub
-# and want to reuse the introduction page also for the doxygen output.
-
-USE_MDFILE_AS_MAINPAGE =
-
-#---------------------------------------------------------------------------
-# Configuration options related to source browsing
-#---------------------------------------------------------------------------
-
-# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
-# generated. Documented entities will be cross-referenced with these sources.
-#
-# Note: To get rid of all source code in the generated output, make sure that
-# also VERBATIM_HEADERS is set to NO.
-# The default value is: NO.
-
-SOURCE_BROWSER         = NO
-
-# Setting the INLINE_SOURCES tag to YES will include the body of functions,
-# classes and enums directly into the documentation.
-# The default value is: NO.
-
-INLINE_SOURCES         = NO
-
-# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
-# special comment blocks from generated source code fragments. Normal C, C++ and
-# Fortran comments will always remain visible.
-# The default value is: YES.
-
-STRIP_CODE_COMMENTS    = YES
-
-# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
-# function all documented functions referencing it will be listed.
-# The default value is: NO.
-
-REFERENCED_BY_RELATION = NO
-
-# If the REFERENCES_RELATION tag is set to YES then for each documented function
-# all documented entities called/used by that function will be listed.
-# The default value is: NO.
-
-REFERENCES_RELATION    = NO
-
-# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
-# to YES, then the hyperlinks from functions in REFERENCES_RELATION and
-# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
-# link to the documentation.
-# The default value is: YES.
-
-REFERENCES_LINK_SOURCE = YES
-
-# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
-# source code will show a tooltip with additional information such as prototype,
-# brief description and links to the definition and documentation. Since this
-# will make the HTML file larger and loading of large files a bit slower, you
-# can opt to disable this feature.
-# The default value is: YES.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-SOURCE_TOOLTIPS        = YES
-
-# If the USE_HTAGS tag is set to YES then the references to source code will
-# point to the HTML generated by the htags(1) tool instead of doxygen built-in
-# source browser. The htags tool is part of GNU's global source tagging system
-# (see http://www.gnu.org/software/global/global.html). You will need version
-# 4.8.6 or higher.
-#
-# To use it do the following:
-# - Install the latest version of global
-# - Enable SOURCE_BROWSER and USE_HTAGS in the config file
-# - Make sure the INPUT points to the root of the source tree
-# - Run doxygen as normal
-#
-# Doxygen will invoke htags (and that will in turn invoke gtags), so these
-# tools must be available from the command line (i.e. in the search path).
-#
-# The result: instead of the source browser generated by doxygen, the links to
-# source code will now point to the output of htags.
-# The default value is: NO.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-USE_HTAGS              = NO
-
-# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
-# verbatim copy of the header file for each class for which an include is
-# specified. Set to NO to disable this.
-# See also: Section \class.
-# The default value is: YES.
-
-VERBATIM_HEADERS       = YES
-
-#---------------------------------------------------------------------------
-# Configuration options related to the alphabetical class index
-#---------------------------------------------------------------------------
-
-# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
-# compounds will be generated. Enable this if the project contains a lot of
-# classes, structs, unions or interfaces.
-# The default value is: YES.
-
-ALPHABETICAL_INDEX     = YES
-
-# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
-# which the alphabetical index list will be split.
-# Minimum value: 1, maximum value: 20, default value: 5.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-COLS_IN_ALPHA_INDEX    = 5
-
-# In case all classes in a project start with a common prefix, all classes will
-# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
-# can be used to specify a prefix (or a list of prefixes) that should be ignored
-# while generating the index headers.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-IGNORE_PREFIX          =
-
-#---------------------------------------------------------------------------
-# Configuration options related to the HTML output
-#---------------------------------------------------------------------------
-
-# If the GENERATE_HTML tag is set to YES doxygen will generate HTML output
-# The default value is: YES.
-
-GENERATE_HTML          = NO
-
-# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
-# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
-# it.
-# The default directory is: html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_OUTPUT            = html
-
-# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
-# generated HTML page (for example: .htm, .php, .asp).
-# The default value is: .html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FILE_EXTENSION    = .html
-
-# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
-# each generated HTML page. If the tag is left blank doxygen will generate a
-# standard header.
-#
-# To get valid HTML the header file should include any scripts and style sheets
-# that doxygen needs, which depend on the configuration options used (e.g. the
-# setting GENERATE_TREEVIEW). It is highly recommended to start with a default
-# header using
-# doxygen -w html new_header.html new_footer.html new_stylesheet.css
-# YourConfigFile
-# and then modify the file new_header.html. See also section "Doxygen usage"
-# for information on how to generate the default header that doxygen normally
-# uses.
-# Note: The header is subject to change so you typically have to regenerate the
-# default header when upgrading to a newer version of doxygen. For a description
-# of the possible markers and block names see the documentation.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_HEADER            =
-
-# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
-# generated HTML page. If the tag is left blank doxygen will generate a standard
-# footer. See HTML_HEADER for more information on how to generate a default
-# footer and what special commands can be used inside the footer. See also
-# section "Doxygen usage" for information on how to generate the default footer
-# that doxygen normally uses.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FOOTER            =
-
-# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
-# sheet that is used by each HTML page. It can be used to fine-tune the look of
-# the HTML output. If left blank doxygen will generate a default style sheet.
-# See also section "Doxygen usage" for information on how to generate the style
-# sheet that doxygen normally uses.
-# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
-# it is more robust and this tag (HTML_STYLESHEET) will in the future become
-# obsolete.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_STYLESHEET        =
-
-# The HTML_EXTRA_STYLESHEET tag can be used to specify an additional user-
-# defined cascading style sheet that is included after the standard style sheets
-# created by doxygen. Using this option one can overrule certain style aspects.
-# This is preferred over using HTML_STYLESHEET since it does not replace the
-# standard style sheet and is therefore more robust against future updates.
-# Doxygen will copy the style sheet file to the output directory. For an example
-# see the documentation.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_EXTRA_STYLESHEET  =
-
-# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
-# other source files which should be copied to the HTML output directory. Note
-# that these files will be copied to the base HTML output directory. Use the
-# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these
-# files. In the HTML_STYLESHEET file, use the file name only. Also note that the
-# files will be copied as-is; there are no commands or markers available.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_EXTRA_FILES       =
-
-# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
-# will adjust the colors in the stylesheet and background images according to
-# this color. Hue is specified as an angle on a colorwheel, see
-# http://en.wikipedia.org/wiki/Hue for more information. For instance the value
-# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
-# purple, and 360 is red again.
-# Minimum value: 0, maximum value: 359, default value: 220.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_HUE    = 220
-
-# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
-# in the HTML output. For a value of 0 the output will use grayscales only. A
-# value of 255 will produce the most vivid colors.
-# Minimum value: 0, maximum value: 255, default value: 100.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_SAT    = 100
-
-# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
-# luminance component of the colors in the HTML output. Values below 100
-# gradually make the output lighter, whereas values above 100 make the output
-# darker. The value divided by 100 is the actual gamma applied, so 80 represents
-# a gamma of 0.8, the value 220 represents a gamma of 2.2, and 100 does not
-# change the gamma.
-# Minimum value: 40, maximum value: 240, default value: 80.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_COLORSTYLE_GAMMA  = 80
-
-# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
-# page will contain the date and time when the page was generated. Setting this
-# to NO can help when comparing the output of multiple runs.
-# The default value is: YES.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_TIMESTAMP         = NO
-
-# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
-# documentation will contain sections that can be hidden and shown after the
-# page has loaded.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_DYNAMIC_SECTIONS  = NO
-
-# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
-# shown in the various tree structured indices initially; the user can expand
-# and collapse entries dynamically later on. Doxygen will expand the tree to
-# such a level that at most the specified number of entries are visible (unless
-# a fully collapsed tree already exceeds this amount). So setting the number of
-# entries 1 will produce a full collapsed tree by default. 0 is a special value
-# representing an infinite number of entries and will result in a full expanded
-# tree by default.
-# Minimum value: 0, maximum value: 9999, default value: 100.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_INDEX_NUM_ENTRIES = 100
-
-# If the GENERATE_DOCSET tag is set to YES, additional index files will be
-# generated that can be used as input for Apple's Xcode 3 integrated development
-# environment (see: http://developer.apple.com/tools/xcode/), introduced with
-# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a
-# Makefile in the HTML output directory. Running make will produce the docset in
-# that directory and running make install will install the docset in
-# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
-# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html
-# for more information.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_DOCSET        = NO
-
-# This tag determines the name of the docset feed. A documentation feed provides
-# an umbrella under which multiple documentation sets from a single provider
-# (such as a company or product suite) can be grouped.
-# The default value is: Doxygen generated docs.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_FEEDNAME        = "Doxygen generated docs"
-
-# This tag specifies a string that should uniquely identify the documentation
-# set bundle. This should be a reverse domain-name style string, e.g.
-# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_BUNDLE_ID       = org.doxygen.Project
-
-# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
-# the documentation publisher. This should be a reverse domain-name style
-# string, e.g. com.mycompany.MyDocSet.documentation.
-# The default value is: org.doxygen.Publisher.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_PUBLISHER_ID    = org.doxygen.Publisher
-
-# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
-# The default value is: Publisher.
-# This tag requires that the tag GENERATE_DOCSET is set to YES.
-
-DOCSET_PUBLISHER_NAME  = Publisher
-
-# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
-# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
-# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
-# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on
-# Windows.
-#
-# The HTML Help Workshop contains a compiler that can convert all HTML output
-# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
-# files are now used as the Windows 98 help format, and will replace the old
-# Windows help format (.hlp) on all Windows platforms in the future. Compressed
-# HTML files also contain an index, a table of contents, and you can search for
-# words in the documentation. The HTML workshop also contains a viewer for
-# compressed HTML files.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_HTMLHELP      = NO
-
-# The CHM_FILE tag can be used to specify the file name of the resulting .chm
-# file. You can add a path in front of the file if the result should not be
-# written to the html output directory.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-CHM_FILE               =
-
-# The HHC_LOCATION tag can be used to specify the location (absolute path
-# including file name) of the HTML help compiler ( hhc.exe). If non-empty
-# doxygen will try to run the HTML help compiler on the generated index.hhp.
-# The file has to be specified with full path.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-HHC_LOCATION           =
-
-# The GENERATE_CHI flag controls if a separate .chi index file is generated (
-# YES) or that it should be included in the master .chm file ( NO).
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-GENERATE_CHI           = NO
-
-# The CHM_INDEX_ENCODING is used to encode HtmlHelp index ( hhk), content ( hhc)
-# and project file content.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-CHM_INDEX_ENCODING     =
-
-# The BINARY_TOC flag controls whether a binary table of contents is generated (
-# YES) or a normal table of contents ( NO) in the .chm file.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-BINARY_TOC             = NO
-
-# The TOC_EXPAND flag can be set to YES to add extra items for group members to
-# the table of contents of the HTML help documentation and to the tree view.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
-
-TOC_EXPAND             = NO
-
-# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
-# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
-# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
-# (.qch) of the generated HTML documentation.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_QHP           = NO
-
-# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
-# the file name of the resulting .qch file. The path specified is relative to
-# the HTML output folder.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QCH_FILE               =
-
-# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
-# Project output. For more information please see Qt Help Project / Namespace
-# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace).
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_NAMESPACE          = org.doxygen.Project
-
-# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
-# Help Project output. For more information please see Qt Help Project / Virtual
-# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual-
-# folders).
-# The default value is: doc.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_VIRTUAL_FOLDER     = doc
-
-# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
-# filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_CUST_FILTER_NAME   =
-
-# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
-# custom filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_CUST_FILTER_ATTRS  =
-
-# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
-# project's filter section matches. Qt Help Project / Filter Attributes (see:
-# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes).
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHP_SECT_FILTER_ATTRS  =
-
-# The QHG_LOCATION tag can be used to specify the location of Qt's
-# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
-# generated .qhp file.
-# This tag requires that the tag GENERATE_QHP is set to YES.
-
-QHG_LOCATION           =
-
-# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
-# generated, together with the HTML files, they form an Eclipse help plugin. To
-# install this plugin and make it available under the help contents menu in
-# Eclipse, the contents of the directory containing the HTML and XML files needs
-# to be copied into the plugins directory of eclipse. The name of the directory
-# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
-# After copying Eclipse needs to be restarted before the help appears.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_ECLIPSEHELP   = NO
-
-# A unique identifier for the Eclipse help plugin. When installing the plugin
-# the directory name containing the HTML and XML files should also have this
-# name. Each documentation set should have its own identifier.
-# The default value is: org.doxygen.Project.
-# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
-
-ECLIPSE_DOC_ID         = org.doxygen.Project
-
-# If you want full control over the layout of the generated HTML pages it might
-# be necessary to disable the index and replace it with your own. The
-# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
-# of each HTML page. A value of NO enables the index and the value YES disables
-# it. Since the tabs in the index contain the same information as the navigation
-# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-DISABLE_INDEX          = NO
-
-# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
-# structure should be generated to display hierarchical information. If the tag
-# value is set to YES, a side panel will be generated containing a tree-like
-# index structure (just like the one that is generated for HTML Help). For this
-# to work a browser that supports JavaScript, DHTML, CSS and frames is required
-# (i.e. any modern browser). Windows users are probably better off using the
-# HTML help feature. Via custom stylesheets (see HTML_EXTRA_STYLESHEET) one can
-# further fine-tune the look of the index. As an example, the default style
-# sheet generated by doxygen has an example that shows how to put an image at
-# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
-# the same information as the tab index, you could consider setting
-# DISABLE_INDEX to YES when enabling this option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_TREEVIEW      = NO
-
-# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
-# doxygen will group on one line in the generated HTML documentation.
-#
-# Note that a value of 0 will completely suppress the enum values from appearing
-# in the overview section.
-# Minimum value: 0, maximum value: 20, default value: 4.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-ENUM_VALUES_PER_LINE   = 4
-
-# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
-# to set the initial width (in pixels) of the frame in which the tree is shown.
-# Minimum value: 0, maximum value: 1500, default value: 250.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-TREEVIEW_WIDTH         = 250
-
-# When the EXT_LINKS_IN_WINDOW option is set to YES doxygen will open links to
-# external symbols imported via tag files in a separate window.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-EXT_LINKS_IN_WINDOW    = NO
-
-# Use this tag to change the font size of LaTeX formulas included as images in
-# the HTML documentation. When you change the font size after a successful
-# doxygen run you need to manually remove any form_*.png images from the HTML
-# output directory to force them to be regenerated.
-# Minimum value: 8, maximum value: 50, default value: 10.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_FONTSIZE       = 10
-
-# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
-# generated for formulas are transparent PNGs. Transparent PNGs are not
-# supported properly for IE 6.0, but are supported on all modern browsers.
-#
-# Note that when changing this option you need to delete any form_*.png files in
-# the HTML output directory before the changes have effect.
-# The default value is: YES.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_TRANSPARENT    = YES
-
-# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
-# http://www.mathjax.org) which uses client side Javascript for the rendering
-# instead of using prerendered bitmaps. Use this if you do not have LaTeX
-# installed or if you want the formulas to look prettier in the HTML output.
-# When enabled you may also need to install MathJax separately and configure
-# the path to it using the MATHJAX_RELPATH option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-USE_MATHJAX            = NO
-
-# When MathJax is enabled you can set the default output format to be used for
-# the MathJax output. See the MathJax site (see:
-# http://docs.mathjax.org/en/latest/output.html) for more details.
-# Possible values are: HTML-CSS (which is slower, but has the best
-# compatibility), NativeMML (i.e. MathML) and SVG.
-# The default value is: HTML-CSS.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_FORMAT         = HTML-CSS
-
-# When MathJax is enabled you need to specify the location relative to the HTML
-# output directory using the MATHJAX_RELPATH option. The destination directory
-# should contain the MathJax.js script. For instance, if the mathjax directory
-# is located at the same level as the HTML output directory, then
-# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
-# Content Delivery Network so you can quickly see the result without installing
-# MathJax. However, it is strongly recommended to install a local copy of
-# MathJax from http://www.mathjax.org before deployment.
-# The default value is: http://cdn.mathjax.org/mathjax/latest.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_RELPATH        = http://cdn.mathjax.org/mathjax/latest
-
-# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
-# extension names that should be enabled during MathJax rendering. For example
-# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_EXTENSIONS     =
-
-# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
-# of code that will be used on startup of the MathJax code. See the MathJax site
-# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
-# example see the documentation.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_CODEFILE       =
-
-# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
-# the HTML output. The underlying search engine uses javascript and DHTML and
-# should work on any modern browser. Note that when using HTML help
-# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
-# there is already a search function so this one should typically be disabled.
-# For large projects the javascript based search engine can be slow, then
-# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
-# search using the keyboard; to jump to the search box use <access key> + S
-# (what the <access key> is depends on the OS and browser, but it is typically
-# <CTRL>, <ALT>/