Merge branch 'ganhead_constructor_validate' of https://github.com/alexpantyukhin/tensorflow into ganhead_constructor_validate

# Conflicts:
#	tensorflow/contrib/gan/python/estimator/python/head_impl.py
This commit is contained in:
apantykhin 2018-06-06 20:08:01 +04:00
commit ef98fc4fb9
3200 changed files with 144262 additions and 61288 deletions

1
.gitignore vendored
View File

@ -27,6 +27,7 @@ Podfile.lock
/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
/api_init_files_list.txt
# Android
.gradle

View File

@ -1,5 +1,16 @@
# Contributing guidelines
## Pull Request Checklist
Before sending your pull requests, make sure you followed this list.
- Read [contributing guidelines](CONTRIBUTING.md).
- Read [Code of Conduct](CODE_OF_CONDUCT.md).
- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style).
- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests).
## How to become a contributor and submit your own code
### Contributor License Agreements

View File

@ -5,9 +5,9 @@
-----------------
| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
|-----------------|---------------------|------------------|-------------------|---------------|---------------|
| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
| **`Documentation`** |
|-----------------|
| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. The graph nodes represent mathematical operations, while
@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU
packages on Linux, Mac, and Windows.
**Individual whl files**
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
#### *Try your first TensorFlow program*
```shell
$ python
@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
## Continuous build status
### Official Builds
| Build Type | Status | Artifacts |
| --- | --- | --- |
| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Linux XLA** | TBA | TBA |
| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) |
| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) |
### Community Supported Builds
| Build Type | Status | Artifacts |
| --- | --- | --- |
| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
## For more information
* [TensorFlow Website](https://www.tensorflow.org)

View File

@ -1,3 +1,62 @@
# Release 1.8.0
## Major Features And Improvements
* Can now pass `tf.contrib.distribute.MirroredStrategy()` to `tf.estimator.RunConfig()` to run an Estimator model on multiple GPUs on one machine.
* Add `tf.contrib.data.prefetch_to_device()`, which supports prefetching to GPU memory.
* Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor.
* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability.
* `tf.contrib.bayesflow` is moving out to it's own repo.
* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication<sup>[1](#rpc-issue)</sup>.
## Bug Fixes and Other Changes
* `tf.data`:
* Add `tf.contrib.data.prefetch_to_device`, which enables prefetching dataset elements to GPU memory.
* Add `tf.contrib.data.AUTOTUNE`, which allows the tf.data runtime to automatically tune the prefetch buffer sizes based on your system and environment.
* Add `tf.contrib.data.make_csv_dataset` for building datasets of CSV files.
* Eager Execution:
* With eager execution Datasets can now be used as standard python iterators (`for batch in dataset:`). Both `Dataset.__iter__()` and `Dataset.make_one_shot_iterator()` can now be used to create iterators when eager execution is enabled.
* Automatic device placement has been enabled (i.e., use a GPU if available automatically, without requiring an explicit `with tf.device(“/gpu:0”)`) (Fixes #14133)
* `tf.GradientTape` has moved out of contrib.
* `tf.keras`:
* Added the fashion mnist dataset.
* New data preprocessing functions: `image/random_brightness`, `sequence/TimeseriesGenerator`, and `text/hashing_trick`.
* Accelerated Linear Algebra (XLA):
* Select and scatter in reference util and evaluator now use lexicographical order to break ties.
* TensorFlow Debugger (tfdbg) CLI:
* During tensor-filter operations, allow exclusion of nodes by regular expressions.
* Fix spurious background colors in some text terminals.
* `tf.contrib`:
* Add meta-distribution BatchReshape which reshapes batch dimensions.
* `tf.contrib.layers.recompute_grad` works for explicit gradient checkpointing on TPU.
* Add `tf.contrib.framework.argsort`.
* Allow `DNNBoostedTreeCombinedEstimator` to work with core versions of feature columns and losses.
* Add non-linear image warping ops: `tf.contrib.image.sparse_image_warp`, `tf.contrib.image.dense_image_warp`, and `tf.contrib.image.interpolate_spline`.
* Fix bug in `tf.contrib.opt.MultitaskOptimizerWrapper` where types of tensors were mismatched.
* Other:
* Low-level graph construction now calls the TensorFlow C API. This change should be invisible to most users, but can be disabled by setting the environment variable `TF_C_API_GRAPH_CONSTRUCTION=0` in this release. Future releases will remove the ability to disable this change. Please [file a bug](https://github.com/tensorflow/tensorflow/issues/new) if you find yourself using this escape hatch.
* Add description of shapes and a pointer to tutorial notebook in `tf.distributions.Distribution`.
* Update scatter operations:
* Add `tf.scatter_min` and `tf.scatter_max`
* Extend scatter operations to work with a scalar update parameter.
* Move cuDNN RNN ops to core for use in TensorFlow codebase only.
* Add `float64` support for `Conv2d`, `Conv2dBackpropInput`, and `Conv2dBackpropFilter`.
* Add `float64` support for `AvgPool`/`AvgPoolGrad`.
* Make graph name scope thread local so that they work correctly in multi-threaded environments.
* Update nsync synchronization library to avoid slow primitives on Linux.
* Removed need to put nsync/public on C include path when building custom ops.
* Add `tf.image.psnr`, `tf.image.ssim`, `tf.image.ssim_multiscale`, `tf.image.image_gradients`, `tf.image.sobel_edges`.
* Add links to https://js.tensorflow.org.
* Fix non-uniformity of orthogonal matrices.
* Fix bug where multi-image Estimator eval summaries were not displayed correctly.
<a name="rpc-issue"><sup>1</sup></a> The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu
# Release 1.7.0
## Major Features And Improvements
@ -177,7 +236,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
* Add `complex64` support to XLA compiler.
* `bfloat` support is now added to XLA infrastructure.
* Make `ClusterSpec` propagation work with XLA devices.
* Use a determinisitic executor to generate XLA graph.
* Use a deterministic executor to generate XLA graph.
* `tf.contrib`:
* `tf.contrib.distributions`:
* Add `tf.contrib.distributions.Autoregressive`.

View File

@ -173,7 +173,7 @@ the progress being made towards a fix and announcement.
In addition, please include the following information along with your report:
* Your name and affiliation (if any).
* A description the technical details of the vulnerabilities. It is very
* A description of the technical details of the vulnerabilities. It is very
important to let us know how we can reproduce your findings.
* An explanation who can exploit this vulnerability, and what they gain when
doing so -- write an attack scenario. This will help us evaluate your report

View File

@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")
http_archive(
name = "io_bazel_rules_closure",
sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657",
strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f",
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13
],
)

View File

@ -226,8 +226,6 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
write_to_bazelrc('build --force_python=py%s' % python_major_version)
write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
@ -500,10 +498,6 @@ def set_cc_opt_flags(environ_cp):
if not is_ppc64le() and not is_windows():
write_to_bazelrc('build:opt --host_copt=-march=native')
write_to_bazelrc('build:opt --define with_default_optimizations=true')
# TODO(mikecase): Remove these default defines once we are able to get
# TF Lite targets building without them.
write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
def set_tf_cuda_clang(environ_cp):
"""set TF_CUDA_CLANG action_env.
@ -847,8 +841,8 @@ def reformat_version_sequence(version_str, sequence_count):
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
'Please specify the CUDA SDK version you want to use, '
'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
'Please specify the CUDA SDK version you want to use. '
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
@ -1228,6 +1222,9 @@ def set_tf_cuda_compute_capabilities(environ_cp):
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
# Check whether all capabilities from the input is valid
all_valid = True
# Remove all whitespace characters before splitting the string
# that users may insert by accident, as this will result in error
tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m:

View File

@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;

View File

@ -184,6 +184,7 @@ library {
return std::move(functions[0]);
}
#if not defined(PLATFORM_WINDOWS)
// On success, returns a set of TF_Function instances encoding a dataset
// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and
// sets `dataset_name` to the created dataset name. The returned functions must
@ -7076,7 +7077,9 @@ library {
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
#endif
}
#endif
#if not defined(PLATFORM_WINDOWS)
// On success, returns a set of TF_Function instances encoding a dataset
// node stack that reads an MNIST file dataset from `file_path`, and
// sets `dataset_name` to the created dataset name. The returned functions must
@ -8221,6 +8224,7 @@ library {
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
#endif
}
#endif
// Adds the input functions to `graph`. On success, returns the created
// IteratorGetNext node.
@ -8314,6 +8318,13 @@ TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph,
TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
TF_Graph* graph, const char* file_path, int batch_size,
unsigned char is_mnist, TF_Status* status) {
#if defined(PLATFORM_WINDOWS)
// TODO(ashankar): get these functions working on Windows.
status->status = tensorflow::errors::Unimplemented(
"TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API "
"is not implemented for Windows");
return nullptr;
#else
tensorflow::Status s;
std::string dataset_name;
@ -8355,4 +8366,92 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
<< graph->graph.ToGraphDefDebug().DebugString();
return getnext_node;
#endif
}
TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
TF_Status* status) {
assert(session);
{
tensorflow::mutex_lock c(session->graph->mu);
VLOG(1) << "Dequeuing named tensor with id " << tensor_id
<< ", with input graph: "
<< session->graph->graph.ToGraphDefDebug().DebugString();
}
TF_Operation* dequeue_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
if (dequeue_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the dequeue node in the TF graph.");
return nullptr;
}
VLOG(1) << "Running the dequeue op";
TF_Output output{dequeue_op, 0};
TF_Tensor* ret;
TF_SessionRun(session, /*run_options*/ nullptr,
// input related parameters
/*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
// output related parameters
/*outputs*/ &output, /*output_values*/ &ret,
/*noutputs*/ 1,
/*targets*/ nullptr, /*ntargets*/ 0,
/*run_metadata*/ nullptr, status);
if (VLOG_IS_ON(1) && status->status.ok()) {
tensorflow::Tensor tensor;
if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
}
}
return ret;
}
void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TF_Tensor* tensor, TF_Status* status) {
assert(session);
{
tensorflow::mutex_lock c(session->graph->mu);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Enqueuing named tensor with id " << tensor_id
<< ", with input graph: "
<< session->graph->graph.ToGraphDefDebug().DebugString();
tensorflow::Tensor internal_tensor;
if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
VLOG(1) << "Enqueu'ing tensor content: "
<< internal_tensor.DebugString();
}
}
}
TF_Operation* enqueue_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
if (enqueue_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the enqueue node in the TF graph.");
return;
}
TF_Operation* placeholder_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
if (placeholder_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the placeholder node as input to enqueue in the TF "
"graph.");
return;
}
VLOG(1) << "Running the enqueue op";
TF_Output input{placeholder_op, 0};
TF_SessionRun(session, /*run_options*/ nullptr,
// input related parameters
/*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
// output related parameters
/*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
/*targets*/ &enqueue_op, /*ntargets*/ 1,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}

View File

@ -86,6 +86,35 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
TF_Graph* graph, const char* file_path, int batch_size,
unsigned char is_mnist, TF_Status* status);
// On success, dequeues a tensor from a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.
// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
// empty, this call is blocked.
//
// Tensors are enqueued via the corresponding TF enqueue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
int tensor_id,
TF_Status* status);
// On success, enqueues `tensor` into a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads
// from a placeholder node "arg_tensor_enqueue_<tensor_id>".
//
// `tensor` is still owned by the caller. This call will be blocked if the queue
// has reached its capacity, and will be unblocked when the queued tensors again
// drop below the capacity due to dequeuing.
//
// Tensors are dequeued via the corresponding TF dequeue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) {
}
const tensorflow::string input_op_name =
tensorflow::ParseTensorName(input_name).first.ToString();
std::string(tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
const tensorflow::string output_op_name =
tensorflow::ParseTensorName(output_name).first.ToString();
std::string(tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
@ -1700,7 +1700,7 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
// REGISTER_OP for CApiTestAttributesTest test cases.
// REGISTER_OP for CApiAttributesTest test cases.
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other
// will have list(type).

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

View File

@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
v2_reader_->key().ToString() /* full var's name */,
std::string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
string key = v2_reader_->key().ToString();
string key = std::string(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}

View File

@ -24,14 +24,13 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:execute_node",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
@ -49,6 +48,18 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core:gpu_runtime",
],
)
@ -59,7 +70,6 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -69,10 +79,23 @@ tf_cuda_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)
@ -91,47 +114,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_library(
name = "runtime",
srcs = ["runtime.cc"],
hdrs = ["runtime.h"],
copts = tf_copts(),
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api",
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
}),
)
tf_cc_test(
name = "runtime_test",
srcs = ["runtime_test.cc"],
deps = [
":runtime",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
],
)

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -32,16 +31,22 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@ -72,6 +77,121 @@ string DeviceName(const tensorflow::Device* d) {
std::atomic_int_fast64_t func_id_generator(0);
#endif // TENSORFLOW_EAGER_USE_XLA
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::vector<tensorflow::Device*> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
tensorflow::Notification n;
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, remote_worker,
[&status, &n, &remote_devices](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
status = s;
if (s.ok()) {
for (tensorflow::Device* d : *devices) {
remote_devices.push_back(d);
}
}
n.Notify();
});
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::DeviceMgr(remote_devices));
TF_RETURN_IF_ERROR(status);
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
const string& remote_worker = remote_workers[i];
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse response;
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
return tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
}
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
tensorflow::Notification n;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
eager_client->CreateContextAsync(
&request, &response, [&status, &n](const tensorflow::Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(status);
remote_contexts->emplace(remote_worker, response.context_id());
}
return tensorflow::Status::OK();
}
tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TFE_Context** ctx) {
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
"/replica:0/task:", opts->server_def.task_index());
std::unique_ptr<tensorflow::eager::EagerGrpcServer> server;
TF_RETURN_IF_ERROR(
tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server));
TF_RETURN_IF_ERROR(server->Start());
std::vector<string> remote_workers;
server->master_env()->worker_cache->ListWorkers(&remote_workers);
remote_workers.erase(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
TF_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, server->master_env()->worker_cache, &remote_device_mgr));
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
remote_eager_workers.get(),
opts->async, &remote_contexts));
tensorflow::RemoteRendezvous* r =
server->worker_env()->rendezvous_mgr->Find(0);
auto* device_mgr = server->worker_env()->device_mgr;
*ctx = new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, r, std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts);
return tensorflow::Status::OK();
}
} // namespace
extern "C" {
@ -92,6 +212,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status) {
if (!options->server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
}
}
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@ -101,28 +230,35 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (!opts->server_def.job_name().empty()) {
TFE_Context* ctx = nullptr;
status->status = NewRemoteAwareTFE_Context(opts, &ctx);
return ctx;
}
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) {
return nullptr;
}
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, std::move(device_mgr), r);
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
delete ctx;
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context.remote_device_mgr()) {
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
}
@ -220,9 +356,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
return retval;
}
} // extern "C"
extern "C" {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
@ -242,21 +375,18 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
tensorflow::Device* d = nullptr;
if (device_name != nullptr && strlen(device_name) > 0) {
status->status = op->ctx->context.FindDeviceByName(device_name, &d);
}
op->device = d;
status->status = op->operation.SetDevice(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device =
(op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext()->HostCPU()
: op->operation.Device();
return device->name().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->use_xla = enable;
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
@ -264,22 +394,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
h->handle->Ref();
op->inputs.push_back(h->handle);
op->attrs.NumInputs(op->inputs.size());
op->operation.AddInput(h->handle);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
if (op->is_function()) {
if (op->operation.is_function()) {
status->status = tensorflow::errors::Unimplemented(
"TODO(apassos): Support for attributes for TensorFlow functions is not "
"ready yet.");
return TF_ATTR_INT; // The compiler requires that we return something.
}
status->status =
tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
return ret;
}
@ -298,23 +426,24 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
}
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
op->attrs.Set(attr_name, value);
op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->attrs.Set(attr_name, static_cast<int64>(value));
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->attrs.Set(attr_name, value);
op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->attrs.Set(attr_name, (value == 0) ? false : true);
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
@ -336,23 +465,24 @@ void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
proto.add_dim()->set_size(dims[d]);
}
}
op->attrs.Set(attr_name, proto);
op->operation.MutableAttrs()->Set(attr_name, proto);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->name);
value->attrs.FillAttrValueMap(func->mutable_attr());
op->attrs.Set(attr_name, attr_value);
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
#define TFE_OP_SET_ATTR_LIST(fn, type) \
void fn(TFE_Op* op, const char* attr_name, const type* values, \
int num_values) { \
op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
values, num_values)); \
op->operation.MutableAttrs()->Set( \
attr_name, \
tensorflow::gtl::ArraySlice<const type>(values, num_values)); \
}
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
@ -360,14 +490,14 @@ TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->attrs.Set(
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
@ -379,8 +509,8 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
@ -410,9 +540,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
}
}
}
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
@ -420,534 +550,25 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->name);
value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
op->attrs.Set(attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
} // extern "C"
namespace {
// Initializes the step stats if needed.
void MaybeInitializeStepStats(tensorflow::StepStats* step_stats,
tensorflow::EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.
while (step_stats->dev_stats_size() < ctx->devices()->size()) {
int device_idx = step_stats->dev_stats_size();
auto* dev_stats = step_stats->add_dev_stats();
dev_stats->set_device(ctx->devices()->at(device_idx)->name());
}
}
int StepStatsDeviceIndex(tensorflow::StepStats* step_stats,
tensorflow::EagerContext* ctx,
tensorflow::Device* device) {
// Find the current device's index.
if (device == nullptr) {
device = ctx->HostCPU();
}
for (int i = 0; i < ctx->devices()->size(); ++i) {
if (ctx->devices()->at(i) == device ||
ctx->devices()->at(i)->name() == device->name()) {
return i;
}
}
// TODO(apassos) do not fall back to host CPU if device is unknown.
return 0;
}
tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) {
tensorflow::Device* host_device = ctx->HostCPU();
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
if (memtypes.size() != op->inputs.size()) {
return tensorflow::errors::InvalidArgument(
"expected ", memtypes.size(), " inputs, got ", op->inputs.size());
}
for (int i = 0; i < op->inputs.size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
tensorflow::TensorHandle* handle = op->inputs[i];
tensorflow::Device* handle_device = nullptr;
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const tensorflow::Device* actual_device =
handle_device == nullptr ? host_device : handle_device;
if (expected_device != actual_device) {
switch (ctx->GetDevicePlacementPolicy()) {
case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
if (handle->dtype == tensorflow::DT_INT32) {
// Note: enabling silent copies of int32 tensors to match behavior
// of graph mode.
break;
}
TF_FALLTHROUGH_INTENDED;
case tensorflow::DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
op->name, " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
actual_device->name(), " (operation running on ",
op_device->name(), ")",
" Tensors can be copied explicitly using .gpu() or .cpu() "
"methods,"
" or transparently copied by using tf.enable_eager_execution("
"device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
"between devices"
" may slow down your model");
case tensorflow::DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
auto pre_time = tensorflow::Env::Default()->NowMicros();
tensorflow::TensorHandle* copied_tensor = nullptr;
tensorflow::Status status = tensorflow::EagerCopyToDevice(
handle, ctx, expected_device->name().c_str(), &copied_tensor);
if (run_metadata != nullptr) {
auto* step_stats = run_metadata->mutable_step_stats();
MaybeInitializeStepStats(step_stats, ctx);
// Record the sending on the source device for now.
int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
auto* node_stats = dev_stats->add_node_stats();
node_stats->set_node_name("_Send");
node_stats->set_all_start_micros(pre_time);
node_stats->set_op_end_rel_micros(
tensorflow::Env::Default()->NowMicros() - pre_time);
}
if (!status.ok()) {
if (copied_tensor != nullptr) copied_tensor->Unref();
return tensorflow::errors::Internal(
"Failed copying input tensor from ", actual_device->name(), " to ",
expected_device->name(), " in order to run ", op->name, ": ",
status.error_message());
}
handle->Unref();
handle = copied_tensor;
op->inputs[i] = copied_tensor;
}
if (handle->dtype != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be a ",
tensorflow::DataTypeString(kernel->input_type(i)),
" tensor but is a ", tensorflow::DataTypeString(handle->dtype),
" tensor");
}
}
return tensorflow::Status::OK();
}
tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
TFE_Context* ctx, TF_Status* status) {
tensorflow::DeviceSet ds;
for (tensorflow::Device* d : *ctx->context.devices()) {
ds.AddDevice(d);
}
tensorflow::DeviceTypeVector final_devices;
status->status = tensorflow::SupportedDeviceTypesForNode(
ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
if (!status->status.ok()) {
return nullptr;
}
if (final_devices.empty()) {
status->status = tensorflow::errors::Internal(
"Could not find valid device for node ", ndef.DebugString());
return nullptr;
}
for (tensorflow::Device* d : *ctx->context.devices()) {
if (d->device_type() == final_devices[0].type_string()) {
return d;
}
}
status->status = tensorflow::errors::Unknown(
"Could not find a device for node ", ndef.DebugString());
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, that the caller can
// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
// `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
TFE_Op* op, std::vector<TF_DataType>* const_input_types,
std::vector<TF_DataType>* arg_input_types,
tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
TF_Status* status) {
DCHECK(!op->is_function());
tensorflow::FunctionDef fdef;
// Get the OpDef of the op we are trying to encapsulate.
TFE_Context* ctx = op->ctx;
const tensorflow::OpRegistrationData* op_data;
{
status->status = ctx->context.FindFunctionOpData(op->name, &op_data);
if (!status->status.ok()) {
return nullptr;
}
}
const tensorflow::OpDef& op_def = op_data->op_def;
tensorflow::OpDef* signature = fdef.mutable_signature();
// Handle constant inputs.
const std::unordered_set<string> const_inputs(
*tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));
// First add place holders for the input args, so that we can refer to them by
// position in the next loop. Also tally up the resource inputs.
int num_resource_inputs = 0;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
++num_resource_inputs;
}
signature->add_input_arg();
}
// Now we map the input params from `op_def` to `signature`, where the param
// ordering for `signature` is: <constants, args, resources>.
int const_index = 0;
int arg_index = const_inputs.size();
int resource_index = op_def.input_arg_size() - num_resource_inputs;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
VLOG(1) << "For const input, mapping op input " << i << " to func input "
<< const_index;
(*op_input_to_func_input)[i] = const_index;
func_input_arg = signature->mutable_input_arg(const_index++);
const_input_types->push_back(
static_cast<TF_DataType>(op->inputs[i]->dtype));
} else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
VLOG(1) << "For resource input, mapping op input " << i
<< " to func input " << resource_index;
(*op_input_to_func_input)[i] = resource_index;
func_input_arg = signature->mutable_input_arg(resource_index++);
} else {
VLOG(1) << "For arg input, mapping op input " << i << " to func input "
<< arg_index;
(*op_input_to_func_input)[i] = arg_index;
func_input_arg = signature->mutable_input_arg(arg_index++);
arg_input_types->push_back(
static_cast<TF_DataType>(op->inputs[i]->dtype));
}
func_input_arg->set_name(op_input_arg.name());
func_input_arg->set_type(op->inputs[i]->dtype);
}
VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
// Resources args are at the end of the function input params, and we should
// have iterated over all of them.
DCHECK_EQ(signature->input_arg_size(), resource_index);
// Make the synthesized function's name unique.
signature->set_name(tensorflow::strings::StrCat(
op_def.name(), func_id_generator.fetch_add(1)));
// Add the node def and set its input names to match op_def's names.
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
*fdef.add_node_def() = ndef;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
}
VLOG(1) << "Added NodeDef: " << fdef.DebugString();
// Fix the output names and set output types.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
const string& out_tensor_name = tensorflow::strings::StrCat(
ndef.name(), ":", op_def_arg.name(), ":", 0);
arg->set_name(op_def_arg.name());
(*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
const string& type_attr = op_def_arg.type_attr();
if (!type_attr.empty()) {
auto i = ndef.attr().find(type_attr);
if (i == ndef.attr().end()) {
status->status = tensorflow::errors::InvalidArgument(
tensorflow::strings::StrCat("Could not find attr ", type_attr,
" in NodeDef ", ndef.DebugString()));
return nullptr;
}
arg->set_type(i->second.type());
}
}
VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
status->status = ctx->context.AddFunctionDef(fdef);
if (!status->status.ok()) return nullptr;
const auto ret = ctx->context.FindFunctionDef(signature->name());
DCHECK(ret != nullptr);
return ret;
}
// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
auto launch_op =
std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
if (op->device) {
TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
const tensorflow::FunctionDef* fdef;
{
fdef = op->ctx->context.FindFunctionDef(op->name);
}
std::vector<TF_DataType> const_input_types;
std::vector<TF_DataType> arg_input_types;
tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
if (fdef == nullptr) {
// See if this is a primitive op, and if so create a function for it, so
// that _XlaLaunchOp can access it.
fdef = OpToFunction(op, &const_input_types, &arg_input_types,
&op_input_to_func_input, status);
if (!status->status.ok()) return nullptr;
} else {
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
// functions, so we need to find another way to handle constant inputs.
for (int i = const_input_types.size();
i < fdef->signature().input_arg_size(); ++i) {
VLOG(1) << "Adding Targs from input arg " << i;
const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
}
}
DCHECK(fdef != nullptr);
// Copy inputs and their devices.
// Since input param reordering may have occurred between `op` and `launch_op`
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
launch_op->inputs = op->inputs;
for (tensorflow::TensorHandle* h : launch_op->inputs) {
h->Ref();
}
if (!op_input_to_func_input.empty()) {
DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
for (int i = 0; i < op_input_to_func_input.size(); ++i) {
VLOG(1) << "mapping op input " << i << " to func input "
<< op_input_to_func_input[i];
launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
}
}
launch_op->attrs.NumInputs(op->inputs.size());
TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
const_input_types.size());
// Set Targs and Nresources attrs.
TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
arg_input_types.size());
const int num_resource_inputs = fdef->signature().input_arg_size() -
const_input_types.size() -
arg_input_types.size();
TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
// Set Tresults attr.
std::vector<TF_DataType> tresults;
for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
tresults.push_back(static_cast<TF_DataType>(arg.type()));
}
TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
tresults.size());
// Set function attr.
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(fdef->signature().name());
launch_op->attrs.Set("function", attr_value);
return launch_op;
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
extern "C" {
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
TFE_Context* ctx = op->ctx;
status->status = ctx->context.GetStatus();
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
*num_retvals);
status->status =
tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
if (!status->status.ok()) {
return;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
std::unique_ptr<TFE_Op> xla_launch_op;
if (op->use_xla && op->name != "_XlaLaunch") {
xla_launch_op = BuildXlaLaunch(op, status);
if (!status->status.ok()) {
return;
}
op = xla_launch_op.get();
}
#endif // TENSORFLOW_EAGER_USE_XLA
// Ensure all resource-touching ops run in the device the resource is,
// regardless of anything else that has been specified. This is identical to
// the graph mode behavior.
for (int i = 0; i < op->inputs.size(); ++i) {
tensorflow::Device* input_op_device = nullptr;
status->status = op->inputs[i]->OpDevice(&input_op_device);
if (!status->status.ok()) return;
VLOG(2) << "for op " << op->name << " input " << i << " "
<< tensorflow::DataTypeString(op->inputs[i]->dtype) << " "
<< (input_op_device == nullptr ? "cpu" : input_op_device->name())
<< " " << (op->device == nullptr ? "cpu" : op->device->name());
if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
(input_op_device != op->device || input_op_device == nullptr)) {
tensorflow::Device* d =
input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
VLOG(1) << "Changing device of operation " << op->name << " to "
<< d->name() << " because input #" << i
<< " is a resource in this device.";
op->device = d;
}
}
tensorflow::Device* device = op->device;
tensorflow::Fprint128 cache_key =
op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
if (device == nullptr) {
device = SelectDevice(ndef, ctx, status);
if (!status->status.ok()) {
return;
}
}
CHECK(device != nullptr);
if (ctx->context.LogDevicePlacement()) {
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
// Knowledge of the implementation of Init (and in-turn
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
status->status = tensorflow::KernelAndDevice::Init(
ndef, ctx->context.func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
}
// Update output_dtypes inside `kernel`.
const tensorflow::OpDef* op_def = nullptr;
const tensorflow::FunctionDef* function_def =
ctx->context.FuncLibDef()->Find(ndef.op());
if (function_def != nullptr) {
op_def = &(function_def->signature());
}
if (op_def == nullptr) {
status->status = OpDefForOp(ndef.op().c_str(), &op_def);
if (!status->status.ok()) {
return;
}
}
tensorflow::DataTypeVector input_dtypes;
status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
kernel->mutable_output_dtypes());
if (!status->status.ok()) {
return;
}
ctx->context.AddKernelToCache(cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
const int output_dtypes_size = output_dtypes.size();
if (output_dtypes_size > *num_retvals) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
" outputs, but *num_retvals is ",
*num_retvals)
.c_str());
return;
}
*num_retvals = output_dtypes_size;
if (device == nullptr) {
// TODO(apassos) debug how the assignment below might return a different
// device from the one requested above.
device = kernel->device();
}
status->status = ValidateInputTypeAndPlacement(
&ctx->context, device, op, kernel->kernel(),
ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto()
: nullptr);
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
if (ctx->context.ShouldStoreMetadata()) {
maybe_stats.reset(new tensorflow::NodeExecStats);
maybe_stats->set_node_name(op->name);
maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
maybe_stats->set_op_start_rel_micros(0);
maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
// TODO(apassos) track referenced tensors
}
if (ctx->context.Async()) {
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
*num_retvals);
tensorflow::uint64 id = op->ctx->context.NextId();
for (int i = 0; i < *num_retvals; ++i) {
tensorflow::TensorHandle* h =
new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
retvals[i] = new TFE_TensorHandle(h);
handle_retvals[i] = h;
}
tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
id, &op->ctx->context, op->device, op->inputs, kernel,
maybe_stats.release(), output_dtypes, handle_retvals);
ctx->context.ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
nullptr);
status->status = tensorflow::EagerExecute(
&op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
}
@ -1090,10 +711,3 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
TFE_Op::~TFE_Op() {
for (tensorflow::TensorHandle* h : inputs) {
h->Unref();
}
}

View File

@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
// workers name). Operations created on this context can then be executed on
// any of these remote workers by setting an appropriate device.
//
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status);
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);

View File

@ -28,14 +28,23 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -45,12 +54,12 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
tensorflow::ServerDef server_def;
};
struct TFE_Context {
@ -64,6 +73,23 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
explicit TFE_Context(
const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy, bool async,
tensorflow::DeviceMgr* local_device_mgr,
tensorflow::Rendezvous* rendezvous,
std::unique_ptr<tensorflow::GrpcServer> server,
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
remote_contexts)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, local_device_mgr, rendezvous, std::move(server),
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_contexts) {}
tensorflow::EagerContext context;
};
@ -85,19 +111,9 @@ struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
: ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
: operation(&ctx->context, op, t) {}
~TFE_Op();
bool const is_function() const { return attr_types == nullptr; }
TFE_Context* ctx; // Must outlive the TFE_Op.
const tensorflow::string name;
tensorflow::AttrBuilder attrs;
const tensorflow::AttrTypeMap* attr_types;
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs;
tensorflow::Device* device;
bool use_xla = false;
tensorflow::EagerOperation operation;
};
namespace tensorflow {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@ -23,7 +24,9 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::string;
@ -220,6 +223,103 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name("localhost");
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name("localhost");
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
ASSERT_TRUE(
tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

View File

@ -130,13 +130,15 @@ class GradientTape {
}
}
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);
void Watch(int64 tensor_id);
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);
@ -170,12 +172,32 @@ class GradientTape {
// Template instantiations here
inline bool IsDtypeTrainable(DataType dtype) {
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
case DT_FLOAT:
case DT_DOUBLE:
case DT_COMPLEX64:
case DT_COMPLEX128:
case DT_RESOURCE:
case DT_VARIANT:
return true;
default:
return false;
}
}
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids) {
for (int64 i : tensor_ids) {
if (tensor_tape_.find(i) != tensor_tape_.end()) {
return true;
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
for (int i = 0; i < tensor_ids.size(); ++i) {
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
if (IsDtypeTrainable(dtypes[i])) {
return true;
}
}
}
return false;
@ -189,9 +211,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id)) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
@ -380,49 +404,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
if (!output_gradients.empty() && output_gradients[i] != nullptr) {
// TODO(apassos) figure out how to print debugging information here.
return errors::InvalidArgument(
"A gradient was provided for a tensor which is used as part of the "
"computation.");
}
} else {
if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id);
if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"failed to find operation producing a tensor");
if (output_gradients.empty() || output_gradients[i] == nullptr) {
auto tensor_it = tensor_tape.find(id);
if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"failed to find operation producing a tensor");
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].id == id) {
found = true;
(*result)[id].push_back(
vspace.Ones(op_it->second.output_tensor_info[j].shape,
op_it->second.output_tensor_info[j].dtype));
break;
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
if (op_it->second.output_tensor_info[j].id == id) {
found = true;
(*result)[id].push_back(
vspace.Ones(op_it->second.output_tensor_info[j].shape,
op_it->second.output_tensor_info[j].dtype));
break;
}
}
if (!found) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"none of operations outputs match expected tensor");
}
} else {
// No record of the target tensor found on the tape, so no gradient
// needs to be computed from it. Do nothing.
}
if (!found) {
return errors::Internal(
"Internal state of the gradient tape is invalid: "
"none of operations outputs match expected tensor");
}
} else {
(*result)[id].push_back(output_gradients[i]);
// No record of the target tensor found on the tape, so no gradient
// needs to be computed from it. Do nothing.
}
} else {
(*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@ -451,8 +465,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients);
tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions

View File

@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Couldn't deserialize HandleData proto");
return;
}
DCHECK(handle_data.is_set());
tensorflow::mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&output.oper->node);
std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
tensorflow::shape_inference::ShapeHandle shape;
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (status->status.ok()) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
} // namespace tensorflow

View File

@ -55,9 +55,15 @@ void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
// `output` if its a resource tensor, or otherwise returns the empty string.
// TODO(b/74620627): remove when _USE_C_SHAPES is removed
std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
// CppShapeInferenceResult::HandleData proto.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
const void* proto, size_t proto_len,
TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_

View File

@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
return name.ToString();
return std::string(name);
}
void InferArgAttributes(const OpDef::ArgDef& arg,

View File

@ -220,7 +220,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
current_constraints.insert(s.ToString());
current_constraints.insert(std::string(s));
}
}
} else {

View File

@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
Input x = Shape(scope, op.input(0));
Input begin = op.input(1);
Input end = op.input(2);
Input strides = op.input(3);
int64 begin_mask;
int64 end_mask;
int64 ellipsis_mask;
int64 new_axis_mask;
int64 shrink_axis_mask;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
grad_outputs->push_back(
StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
StridedSliceGrad::BeginMask(begin_mask)
.EndMask(end_mask)
.EllipsisMask(ellipsis_mask)
.NewAxisMask(new_axis_mask)
.ShrinkAxisMask(shrink_axis_mask)));
// No gradients returned for begin, end and strides
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
} // anonymous namespace
} // namespace ops
} // namespace tensorflow

View File

@ -354,5 +354,29 @@ TEST_F(ArrayGradTest, MirrorPadGradGrad_Symmetric) {
RunTest(x, x_shape, y, y_shape);
}
TEST_F(ArrayGradTest, StridedSliceGrad) {
TensorShape x_shape({6, 4, 4});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// y = x[2:6:2, 1:3, 1:3]
auto y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1});
// y.shape = [2, 2, 2];
RunTest(x, x_shape, y, {2, 2, 2});
// y = x[2:6:2, 1:3, 1:3]
// begin_mask = 1<<1 (ignore begin_index = 1)
// end_mask = 1<<2 (ignore end_index = 2)
y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1},
StridedSlice::BeginMask(1 << 1).EndMask(1 << 2));
// y.shape = [2, 3, 3];
RunTest(x, x_shape, y, {2, 3, 3});
// y = [tf.newaxis, 2:6:2, 1:3, 1:3]
y = StridedSlice(scope_, x, {0, 2, 1, 1}, {0, 6, 3, 3}, {1, 2, 1, 1},
StridedSlice::NewAxisMask(1 << 0));
// y.shape = [1, 2, 2, 2];
RunTest(x, x_shape, y, {1, 2, 2, 2});
}
} // namespace
} // namespace tensorflow

View File

@ -31,7 +31,6 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
using ops::Greater;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@ -46,7 +45,6 @@ using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
using ops::Where3;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/tools/freeze_saved_model.h"
#include <iostream>
#include <queue>
#include "tensorflow/core/framework/attr_value.pb.h"
@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap(
}
}
// Strips off the tensor part of the tensor_name to get the node_name.
const string GetNodeNameFromTensorName(string tensor_name) {
if (tensor_name[0] == '^') {
tensor_name.erase(0, 1);
}
std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
return tensor_name_parts[0];
}
// Gets the set of node names needed by `outputs` and the corresponding set of
// variable nodes to convert.
void GetReachableNodesAndVariables(
@ -83,10 +93,8 @@ void GetReachableNodesAndVariables(
new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
std::queue<string> nodes_to_visit;
for (const string& tensor_name : outputs) {
// We need to strip off the tensor part to get the node name.
std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
nodes_to_visit.push(tensor_name_parts[0]);
for (const string& output_tensor_name : outputs) {
nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
}
// We do a traversal backwards from the outputs specified in the MetaGraphDef.
while (!nodes_to_visit.empty()) {
@ -100,8 +108,8 @@ void GetReachableNodesAndVariables(
if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
variable_node_names->insert(node->name());
}
for (const string& input : node->input()) {
nodes_to_visit.push(input);
for (const string& input_tensor_name : node->input()) {
nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
}
}
}

View File

@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
GraphDefEqual(frozen_graph_def, graph_def);
}
TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
// Tensors from operations with multiple outputs get tensor suffixes when used
// in input fields of following nodes, i.e. split:0, split:1.
// Test that we traverse those correctly.
SavedModelBundle saved_model_bundle;
GraphDef graph_def;
Scope scope = Scope::NewRootScope();
Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2});
Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output;
Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
Output c = ops::Mul(scope.WithOpName("c"), split[1], b);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
&saved_model_bundle));
GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
&outputs));
GraphDefEqual(frozen_graph_def, graph_def);
}
TEST_F(FreezeTest, GraphDefWithControlDependency) {
// Inputs that are control dependencies get tensor prefixes,
// i.e. ^control_dependency.
// Test that we traverse those correctly.
SavedModelBundle saved_model_bundle;
GraphDef graph_def;
Scope scope = Scope::NewRootScope();
Output source = ops::Const(scope.WithOpName("source"), 10.0f, {});
Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source),
{10.0f, 10.0f}, {2});
Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
Output c = ops::Mul(scope.WithOpName("c"), a, b);
TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
&saved_model_bundle));
GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
&outputs));
GraphDefEqual(frozen_graph_def, graph_def);
}
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}

View File

@ -214,7 +214,6 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:execution_engine",
"@llvm//:support",
"@llvm//:target",
],

View File

@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
const string include_hlo_profile_printer_data_proto =
opts.gen_hlo_profile_printer_data
? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
: "";
// When HLO profiling is disabled we only forward declare the
// HloProfilePrinter protobuf. So we can only conditionally emit this code
// calling HloProfilePrinter::profile_counters_size.
const string assign_profile_counters_size =
opts.gen_hlo_profile_printer_data
? "data->profile_counters_size = "
"data->hlo_profile_printer_data->profile_counters_size();"
: "";
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
return *kStaticData;
@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
// Metadata that can be used to pretty-print profile counters.
static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
static const xla::HloProfilePrinterData* kHloProfilePrinterData =
{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
return kHloProfilePrinterData;
}
};
{{NS_END}}
@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
str_util::Join(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
{"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
include_hlo_profile_printer_data_proto},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
{"{{DECLS_FROM_OBJ_FILE}}",
str_util::Join(metadata_result.header_variable_decls, "\n")},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim}};
{"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) {
static string CreateUniqueIdentifier(const CodegenOpts& opts,
StringPiece suffix) {
string result = "__tfcompile";
for (const string& n : opts.namespaces) {
strings::StrAppend(&result, "_", n);
}
strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape");
strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
return result;
}
@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts,
// When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
// a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{
CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
"xla::HloProfilePrinterData",
compile_result.aot->hlo_profile_printer_data()};
TF_ASSIGN_OR_RETURN(
EmbeddedProtocolBuffer embedded_program_shape,
CreateEmbeddedProtocolBuffer(opts.target_triple,
CreateUniqueIdentifierForProgramShape(opts),
"xla::ProgramShape", program_shape.get()));
EmbeddedProtocolBuffers embedded_protobufs,
CreateEmbeddedProtocolBuffers(
opts.target_triple,
{program_shape_protobuf, hlo_profile_printer_data_protobuf}));
metadata_result->program_shape_access_shim =
std::move(embedded_program_shape.cpp_shim_expression);
std::move(embedded_protobufs.cpp_shims[0].expression);
metadata_result->hlo_profile_printer_data_access_shim =
std::move(embedded_protobufs.cpp_shims[1].expression);
metadata_result->header_variable_decls.emplace_back(
std::move(embedded_program_shape.cpp_variable_decl));
std::move(embedded_protobufs.cpp_shims[0].variable_decl));
metadata_result->header_variable_decls.emplace_back(
std::move(embedded_protobufs.cpp_shims[1].variable_decl));
metadata_result->object_file_data =
std::move(embedded_program_shape.object_file_data);
std::move(embedded_protobufs.object_file_data);
return Status::OK();
}

View File

@ -44,6 +44,10 @@ struct CodegenOpts {
// If true, generate program shape data for the ProgramShape method.
bool gen_program_shape = false;
// If true, emit a serialized HloProfilePrinterData protobuf that can be used
// to pretty print HLO profile counters.
bool gen_hlo_profile_printer_data = false;
};
// Describes a generated metadata object file.
@ -57,6 +61,12 @@ struct MetadataResult {
// GenerateMetadata.
string program_shape_access_shim;
// hlo_profile_printer_data_access_shim is a C++ expression that constructs
// the xla::HloProfilePrinterData instance for the CompileResult passed to
// GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a
// C++ expression that evaluates to nullptr at runtime.
string hlo_profile_printer_data_access_shim;
// The contents of the object (".o") file.
string object_file_data;
};

View File

@ -172,7 +172,7 @@ TEST(CodegenTest, Golden) {
fetch->set_name("myfetch");
CompileResult compile_result;
compile_result.aot.reset(
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),

View File

@ -10,6 +10,7 @@
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
@ -23,6 +24,7 @@ extern "C" void entry_point(
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
namespace foo {
namespace bar {
@ -54,9 +56,9 @@ namespace bar {
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 128
// arg bytes aligned: 192
// temp bytes total: 126
// temp bytes aligned: 224
// temp bytes aligned: 320
class MyClass : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
data->hlo_profile_printer_data = StaticHloProfilePrinterData();
return data;
}();
return *kStaticData;
@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}();
return kShape;
}
// Metadata that can be used to pretty-print profile counters.
static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
static const xla::HloProfilePrinterData* kHloProfilePrinterData =
nullptr;
return kHloProfilePrinterData;
}
};
} // end namespace bar

View File

@ -44,7 +44,7 @@ namespace {
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
const xla::XlaComputation& computation,
const xla::cpu::CpuAotCompilationOptions& aot_opts,
CompileResult* compile_result) {
// Retrieves arg and result layouts from the computation.
@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
xla::CompileOnlyClient::AotComputationInstance instance;
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
@ -88,20 +88,19 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
namespace gpu = perftools::gputools;
gpu::Platform* cpu_platform =
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
se::Platform* cpu_platform =
se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
xla::XlaComputation computation;
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
// Serialize the SessionModule deterministically so that all the outputs of
// a tf_library genrule are deterministic.
// Serialize the HloSnapshot deterministically so that all the outputs of a
// tf_library genrule are deterministic.
string proto;
TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
TF_RETURN_IF_ERROR(
@ -111,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
return CompileXla(client, computation, aot_opts, compile_result);
}

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <string>
#include "llvm/ADT/Triple.h"
#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
@ -37,9 +36,8 @@ namespace tfcompile {
using xla::llvm_ir::AsStringRef;
static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine,
const ::tensorflow::protobuf::MessageLite& proto,
static void AddEmbeddedProtocolBufferToLlvmModule(
llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
StringPiece unique_identifier, string* protobuf_array_symbol_name,
int64* protobuf_array_size) {
string protobuf_array_contents = proto.SerializeAsString();
@ -47,19 +45,14 @@ static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
strings::StrCat(unique_identifier, "_protobuf_array_contents");
*protobuf_array_size = protobuf_array_contents.size();
std::unique_ptr<llvm::Module> module =
MakeUnique<llvm::Module>("embedded_data_module", *llvm_context);
llvm::Constant* protobuf_array_initializer =
llvm::ConstantDataArray::getString(*llvm_context,
llvm::ConstantDataArray::getString(module->getContext(),
AsStringRef(protobuf_array_contents),
/*AddNull=*/false);
new llvm::GlobalVariable(
*module, protobuf_array_initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
return module;
}
static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
@ -116,42 +109,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
/*Features=*/"", llvm::TargetOptions(), llvm::None));
}
StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
StringPiece target_triple, StringPiece symbol_prefix,
StringPiece qualified_cpp_protobuf_name,
const ::tensorflow::protobuf::MessageLite* proto) {
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
StringPiece target_triple,
gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
GetTargetMachineFromTriple(target_triple));
llvm::LLVMContext llvm_context;
string object_file, cpp_shim, cpp_variable_decl;
std::unique_ptr<llvm::Module> module_with_serialized_proto =
MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
if (proto) {
string protobuf_array_symbol_name;
int64 protobuf_array_size;
EmbeddedProtocolBuffers result;
std::unique_ptr<llvm::Module> module_with_serialized_proto =
CreateModuleWithEmbeddedProtocolBuffer(
&llvm_context, target_machine.get(), *proto, symbol_prefix,
&protobuf_array_symbol_name, &protobuf_array_size);
TF_ASSIGN_OR_RETURN(object_file,
CodegenModule(target_machine.get(),
std::move(module_with_serialized_proto)));
cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name,
protobuf_array_symbol_name,
protobuf_array_size);
for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
string cpp_shim, cpp_variable_decl;
if (protobuf_to_embed.message) {
string protobuf_array_symbol_name;
int64 protobuf_array_size;
cpp_variable_decl = strings::StrCat("extern \"C\" char ",
protobuf_array_symbol_name, "[];");
} else {
TF_ASSIGN_OR_RETURN(
object_file,
CodegenModule(target_machine.get(),
MakeUnique<llvm::Module>("empty_module", llvm_context)));
cpp_shim = "nullptr";
AddEmbeddedProtocolBufferToLlvmModule(
module_with_serialized_proto.get(), *protobuf_to_embed.message,
protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
&protobuf_array_size);
cpp_shim = CreateCPPShimExpression(
protobuf_to_embed.qualified_cpp_protobuf_name,
protobuf_array_symbol_name, protobuf_array_size);
cpp_variable_decl = strings::StrCat("extern \"C\" char ",
protobuf_array_symbol_name, "[];");
} else {
cpp_shim = "nullptr";
}
result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
}
return {{cpp_shim, cpp_variable_decl, object_file}};
TF_ASSIGN_OR_RETURN(result.object_file_data,
CodegenModule(target_machine.get(),
std::move(module_with_serialized_proto)));
return result;
}
} // namespace tfcompile

View File

@ -21,51 +21,70 @@ limitations under the License.
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace tfcompile {
using xla::StatusOr;
// Represents a protocol buffer embedded into an object file and describes a way
// to access it at runtime.
struct EmbeddedProtocolBuffer {
// cpp_shim_expression is a C++ expression that creates an instance of said
// protocol buffer when executed.
string cpp_shim_expression;
// Represents a set of protocol buffers embedded into an object file and
// describes how to access them at runtime.
struct EmbeddedProtocolBuffers {
// Each instance CPPShim describes how to generate C++ code to instantiate a
// protobuf instance from the corresponding static data emitted into the
// object file.
struct CPPShim {
// `expression` is a C++ expression that creates an instance of said
// protocol buffer when executed.
string expression;
// cpp_variable_decl is an "extern C" array declaration that is used in
// cpp_shim_expression. It must be visible wherever cpp_shim_expression is
// emitted.
string cpp_variable_decl;
// `variable_decl` is an "extern C" array declaration that is used in
// `expression`. It must be visible wherever `expression` is emitted.
string variable_decl;
};
// The contents of the object (".o") file the protocol buffer is embbed in.
// This needs to be linked in to any program that wants to execute
// cpp_variable_decl .
// Each cpp_shim corresponds to one embedded protocol buffer.
std::vector<CPPShim> cpp_shims;
// The contents of the object (".o") file the protocol buffers are embbed in.
// This needs to be linked in to any program that wants to execute any of the
// expressions in `cpp_shims`.
string object_file_data;
};
// Creates an object file that contains `proto`.
//
// `proto` is allowed to be nullptr, in which case the generated C++ shim
// expression is just `nullptr`, and the generated object file does not define
// any symbols.
// Describes a protocol buffer to embed into an object file.
struct ProtobufToEmbed {
// `symbol_prefix` is prefix that is guaranteed to be unique across the binary
// or DSO the generated object file will be linked into.
string symbol_prefix;
// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
// namespace qualified) protocol buffer name. This is only used in
// CPPShim::expression so relatively qualified names are fine as long as
// they're valid wherever CPPShim::expression is emitted.
string qualified_cpp_protobuf_name;
// `message` is the protocol buffer to be embedded. It is allowed to be
// nullptr, in which case the generated C++ shim expression is just `nullptr`,
// and the generated object file does not define any symbols.
const ::tensorflow::protobuf::MessageLite* message;
};
// Embeds a a sequence of protocol buffers into an object file.
//
// `target_triple` is the target triple for the target architecture for the
// generated object file.
//
// `symbol_prefix` is prefix that is guaranteed to be unique across the binary
// or DSO the generated object file will be linked into.
//
// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
// namespace qualified) protocol buffer name. This needs is only used in
// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified
// names are fine as long as they're valid wherever cpp_shim_expression
// is emitted.
StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
StringPiece target_triple, StringPiece symbol_prefix,
StringPiece qualified_cpp_protobuf_name,
const ::tensorflow::protobuf::MessageLite* proto);
// `protobufs_to_embed` describes the protocol buffers to embed into the
// resulting object file. The C++ shim for protobufs_to_embed[i] is
// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents
// of all the protocol buffers are embedded into a single .o file whose content
// is stored in the object_file_data field in the returned
// EmbeddedProtocolBuffers instance.
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
StringPiece target_triple,
gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -25,8 +25,8 @@ namespace tensorflow {
namespace tfcompile {
namespace runtime {
// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
static constexpr size_t kAlign = 32;
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
static constexpr size_t kAlign = 64;
// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1
// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign

View File

@ -24,7 +24,7 @@ namespace runtime {
namespace {
TEST(Runtime, AlignmentValue) {
// We've chosen 32 byte alignment for the tfcompile runtime to mimic the
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
// The tfcompile runtime also has a requirement that comes from the xla
// generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8
@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) {
EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32);
EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32};
EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32);
EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192);
EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320);
}
void* add_ptr(void* base, uintptr_t delta) {
@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) {
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
EXPECT_EQ(bufD[2], add_ptr(base, 32));
EXPECT_EQ(bufD[2], add_ptr(base, 64));
EXPECT_EQ(bufD[3], nullptr);
EXPECT_EQ(bufD[4], add_ptr(base, 64));
EXPECT_EQ(bufD[5], add_ptr(base, 128));
EXPECT_EQ(bufD[6], add_ptr(base, 160));
EXPECT_EQ(bufD[4], add_ptr(base, 128));
EXPECT_EQ(bufD[5], add_ptr(base, 192));
EXPECT_EQ(bufD[6], add_ptr(base, 256));
for (int i = 0; i < 7; ++i) {
const intptr_t size = sizesD[i];
if (size != -1) {

View File

@ -35,6 +35,7 @@ limitations under the License.
// clang-format on
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

View File

@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
"test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@ -118,6 +120,17 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfcond",
testonly = 1,
config = "test_graph_tfcond.config.pbtxt",
cpp_class = "CondComp",
graph = "test_graph_tfcond.pb",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tffunction",
testonly = 1,
@ -163,6 +176,15 @@ tf_library(
tfcompile_flags = "--gen_name_to_index --gen_program_shape",
)
tf_library(
name = "test_graph_tfmatmulandadd_with_profiling",
testonly = 1,
config = "test_graph_tfmatmulandadd.config.pbtxt",
cpp_class = "MatMulAndAddCompWithProfiling",
enable_xla_hlo_profiling = True,
graph = "test_graph_tfmatmulandadd.pb",
)
tf_library(
name = "test_graph_tfsplits",
testonly = 1,
@ -185,13 +207,18 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",

View File

@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())
def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
def tfcond(_):
p = array_ops.placeholder(dtypes.bool, name='p_hold')
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
z = control_flow_ops.cond(p, lambda: x, lambda: y)
array_ops.identity(z, name='result')
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')
def tfassert_eq(_):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = array_ops.placeholder(dtypes.int32, name='y_hold')
control_flow_ops.Assert(
math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
math_ops.add(x, math_ops.negative(y), name='x_y_diff')
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
write_graph(tfcond, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':

View File

@ -0,0 +1,20 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "p_hold" }
shape {}
}
feed {
id { node_name: "x_hold" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "result" }
}

View File

@ -21,19 +21,27 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
using ::testing::HasSubstr;
using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) {
AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
@ -143,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
TEST(TFCompileTest, Cond) {
CondComp cond;
EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
cond.arg1() = 10;
cond.arg2() = 20;
{
cond.arg0() = true;
const int32 expected_result = cond.arg1();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
{
cond.arg0() = false;
const int32 expected_result = cond.arg2();
EXPECT_TRUE(cond.Run());
EXPECT_EQ(cond.result0(), expected_result);
EXPECT_EQ(cond.result0_data()[0], expected_result);
EXPECT_EQ(cond.result0_data(), cond.results()[0]);
}
}
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
@ -484,6 +517,56 @@ TEST(TFCompileTest, ProgramShape) {
EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
}
TEST(TFCompileTest, HloProfiling) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
MatMulAndAddCompWithProfiling fn;
ASSERT_TRUE(fn.hlo_profiling_enabled());
fn.set_thread_pool(&device);
// x = [[1, 2], [3, 4]]
fn.arg0(0, 0) = 1;
fn.arg0(0, 1) = 2;
fn.arg0(1, 0) = 3;
fn.arg0(1, 1) = 4;
// y = [[10, 20], [30, 40]]
fn.arg1(0, 0) = 10;
fn.arg1(0, 1) = 20;
fn.arg1(1, 0) = 30;
fn.arg1(1, 1) = 40;
EXPECT_TRUE(fn.Run());
string hlo_profile_as_string =
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
/*clock_rate_ghz=*/1.0);
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
tensorflow::str_util::Split(hlo_profile_as_string, '\n');
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto add_profile_line = HasSubstr(
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
add_profile_line, tuple_profile_line}));
}
} // namespace
} // namespace tfcompile
} // namespace tensorflow

View File

@ -25,7 +25,8 @@ def tf_library(name, graph, config,
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps=True, deps=None, tags=None):
include_standard_runtime_deps=True,
enable_xla_hlo_profiling=False, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
Given an invocation of tf_library(name="foo", ...), generates the following
@ -68,6 +69,8 @@ def tf_library(name, graph, config,
include_standard_runtime_deps: If True, the standard list of kernel/runtime
deps is added to deps. If False, deps must contain the full set of deps
needed by the generated library.
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
and emit metadata that lets us pretty-print the gathered profile counters.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
@ -137,6 +140,10 @@ def tf_library(name, graph, config,
flags = tfcompile_flags
else:
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
else:
profiling_flag = ""
native.genrule(
name=("gen_" + name),
srcs=[
@ -157,7 +164,7 @@ def tf_library(name, graph, config,
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
" " + flags),
" " + flags + " " + profiling_flag),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
@ -220,6 +227,8 @@ def tf_library(name, graph, config,
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
"//tensorflow/compiler/xla:xla_data_proto",
] or []) + (enable_xla_hlo_profiling and [
"//tensorflow/compiler/xla/service:hlo_profile_printer_data"
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",

View File

@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) {
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));

View File

@ -124,7 +124,6 @@ cc_library(
srcs = ["xla_tensor.cc"],
hdrs = ["xla_tensor.h"],
deps = [
":common",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
@ -176,6 +175,7 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:sendrecv_ops",
@ -216,6 +216,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
@ -256,23 +257,11 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "graph_to_functiondef",
srcs = ["graph_to_functiondef.cc"],
hdrs = ["graph_to_functiondef.h"],
visibility = [":friends"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
"create_xla_launch_op.h",
],
deps = [
":common",
@ -282,6 +271,27 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
tf_cc_test(
name = "create_xla_launch_op_test",
srcs = [
"create_xla_launch_op.h",
"create_xla_launch_op_test.cc",
],
deps = [
":create_xla_launch_op",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
@ -299,7 +309,6 @@ cc_library(
],
deps = [
":common",
":graph_to_functiondef",
":shape_inference_helpers",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
@ -318,6 +327,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
)
@ -345,28 +355,6 @@ tf_cc_test(
],
)
tf_cc_test(
name = "graph_to_functiondef_test",
size = "small",
srcs = [
"graph_to_functiondef_test.cc",
],
deps = [
":graph_to_functiondef",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "compilation_passes_test",
size = "small",
@ -377,7 +365,6 @@ tf_cc_test(
deps = [
":common",
":compilation_passes",
":graph_to_functiondef",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
@ -395,6 +382,31 @@ tf_cc_test(
],
)
tf_cc_test(
name = "xla_launch_util_test",
size = "small",
srcs = ["xla_launch_util_test.cc"],
deps = [
":common",
":xla_compilation_cache",
":xla_launch_util",
":xla_tensor",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/kernels:variable_ops",
],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",

View File

@ -40,7 +40,7 @@ static Status BuildLaunchNode(
Graph* graph, Node** node) {
NodeDef def;
def.set_name(graph->NewName(nodename));
def.set_op("_XlaLaunch");
def.set_op("XlaLaunch");
def.set_device(device_name);
AddNodeAttr("Tconstants", constant_dtypes, &def);
AddNodeAttr("Targs", arg_dtypes, &def);
@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
node->input_types().begin() + num_constant_args,
node->input_types().begin() + num_constant_args + num_nonconst_args);
// Build a _XlaLaunch operator to execute the function body.
// Build a XlaLaunch operator to execute the function body.
Node* launch_node;
TF_RETURN_IF_ERROR(BuildLaunchNode(
graph->NewName(node->name()), node->type_string(), node->def().attr(),

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
@ -21,82 +22,194 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
// 'ndef' is a call to a compilable function defined in 'flr', returns OK
// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
// node. Otherwise, returns a non-OK.
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// This routine is here so that FunctionLibraryRuntime can jit a
// specific function call as requested.
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
std::unique_ptr<OpKernel>* kernel) {
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
Status CompilationRequested(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) {
bool xla_compile = false;
if (!flr->GetFunctionLibraryDefinition()
->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
.ok() ||
!xla_compile) {
// Not marked as _XlaCompile=true.
return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
}
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
if (!IsCompilable(flr, ndef)) {
// ndef is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
// Check if op is marked _XlaCompile=true.
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
if (VLOG_IS_ON(3)) {
if (!status.ok()) {
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
<< node_def.op() << ". status=" << status.ToString();
} else {
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
}
}
return Status(error::INVALID_ARGUMENT, "");
}
return Status::OK();
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
// If ndef is not instantiable, e.g., the function does not exist,
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody); // Can't be nullptr since we just instantiated it.
std::vector<bool> const_args(fbody->arg_types.size());
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
// There is a const arg. Bail out.
return errors::InvalidArgument("Const arg: ", i, " in ",
DebugString(fbody->fdef));
constant_arg_indices->push_back(i);
}
}
NodeDef launch_def;
launch_def.set_name(ndef.name());
launch_def.set_op("_XlaLaunch");
launch_def.set_device(flr->device()->name());
AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
AddNodeAttr("Nresources", 0, &launch_def);
AddNodeAttr("Targs", fbody->arg_types, &launch_def);
AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
NameAttrList func;
func.set_name(ndef.op());
*(func.mutable_attr()) = ndef.attr();
AddNodeAttr("function", func, &launch_def);
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
}
// TODO(b/32387911): Handles the host memory types across function
// calls properly. For now, we assume all inputs and outputs are on
// the device memory.
return Status::OK();
}
} // namespace
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
if (!IsCompilable(flr, node_def)) {
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument("Not compilable: ",
node_def.ShortDebugString());
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instatiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
// Create the kernel.
NameAttrList function;
function.set_name(node_def.op());
*(function.mutable_attr()) = node_def.attr();
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &launch_def,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
kernel->reset(new XlaLocalLaunchOp(&construction));
*kernel = MakeUnique<XlaLocalLaunchBase>(&construction, constant_arg_indices,
resource_arg_indices, function);
return s;
}
namespace {
bool RegisterLaunchOpCreator() {
RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
return true;

View File

@ -0,0 +1,35 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class FunctionLibraryRuntime;
class OpKernel;
// Given a NodeDef 'node_def' and the function library runtime 'flr', if
// 'node_def' is a call to a compilable function defined in 'flr', returns OK
// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
// node. Otherwise, returns a non-OK.
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_

View File

@ -0,0 +1,145 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
NodeDef ToNodeDef(const string& text) {
NodeDef node_def;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
}
// Create a FunctionDef that takes one resource and one regular param
FunctionDef XTimesY() {
return FunctionDefHelper::Define(
// Name
"XTimesY",
// Args
{"x: float", "y: resource"},
// Return values
{"z: float"},
// Attr def
{},
// Nodes
{
{{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
{{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
});
}
class CreateXlaLaunchOpTest : public ::testing::Test {
protected:
void Init(const std::vector<FunctionDef>& flib) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_));
FunctionDefLibrary proto;
for (const auto& fdef : flib) {
*(proto.add_function()) = fdef;
}
lib_def_ =
MakeUnique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
OptimizerOptions opts;
device_mgr_ = MakeUnique<DeviceMgr>(devices_);
pflr_ = MakeUnique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
}
FunctionLibraryRuntime* flr_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
std::unique_ptr<OpKernel> kernel_;
};
AttrValue BoolAttr(bool b) {
AttrValue v;
v.set_b(b);
return v;
}
TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
Init({fdef});
Status status = CreateXlaLaunchOp(
flr_, ToNodeDef(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb"), &kernel_);
ASSERT_TRUE(status.ok()) << status.ToString();
EXPECT_EQ("XTimesY", kernel_->name());
EXPECT_EQ("XTimesY", kernel_->type_string());
EXPECT_EQ(2, kernel_->num_inputs());
EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
EXPECT_EQ(1, kernel_->num_outputs());
EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
}
TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
FunctionDef fdef = XTimesY();
Init({fdef});
Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"), &kernel_);
EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
}
TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
Init({fdef});
Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"), &kernel_);
EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
}
} // namespace tensorflow

View File

@ -22,7 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
@ -160,6 +161,11 @@ class Encapsulator {
std::move(outside_compilation_attribute)),
graph_in_(graph_in) {}
// Find dependencies between subgraphs and outside_compilation clusters that
// only manifest via edges between outside_compilation clusters in the outer
// (non-compiled) graph.
Status FindClusterDependencies();
// Find subgraphs marked with 'group_attribute', and build a new
// subgraph, one for each value of 'group_attribute'.
Status SplitIntoSubgraphs();
@ -230,6 +236,19 @@ class Encapsulator {
// the shapes of any ancestor RAH outputs. If it can be determined that the
// shape of the SFH inputs will not be inferrable even once the shapes of the
// RAH outputs are known, an error is returned by the rewriter.
//
// Once edges between compiled and outside_compilation clusters have been
// replaced by send/recv ops, some dependencies may no longer be apparent.
// A clustering pass finds all the dependencies between HC nodes that are only
// present as a result of edges between nodes in outside_compilation clusters.
// Suppose there is a path from outside_compilation cluster C in subgraph S
// to outside_compilation cluster D in subgraph T. If S != T then a control
// edge is added from the call node for S to the call node for T, which
// ensures that C will execute before D because S executes before T. If S==T
// then a control dependency is added between the HC nodes for C and D in S,
// and the HC node for C is added to an 'ancestors' attr in the HC node for D
// so that during compilation of the HC node for D, an XLA control dependency
// can be added to ensure C's SendToHost executes before D's RecvFromHost.
class Subgraph {
public:
// Creates a graph to build the subgraph in, if it doesn't already exist,
@ -324,6 +343,18 @@ class Encapsulator {
void RecordOutsideCompilationOutputOrControl(
const string& outside_compilation_id, const Edge* edge);
// Records the fact that there is a path from a node in outside_compilation
// cluster ancestor to node in cluster successor that does not go through
// the subgraph.
void RecordOutsideCompilationDependency(const string& successor,
const string& ancestor);
// Returns the mapping from outside_compilation cluster C to the set of
// outside_compilation clusters that have a path to C entirely outside
// compiled subgraphs.
const std::unordered_map<string, std::unordered_set<string>>
OutsideCompilationAncestorMap() const;
// Adds the HostCompute nodes for each outside_compilation subgraph.
Status AddHostComputes(
const string& subgraph_name,
@ -406,6 +437,13 @@ class Encapsulator {
Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph,
Graph* graph_out);
// Get the set of outside_compilation clusters and the dependency edges
// between them.
void GetActiveClusterDependencyGraph(
std::unordered_set<string>* clusters,
std::unordered_set<string>* has_successor,
std::unordered_map<string, std::unordered_set<string>>* ancestors_map);
// Builds a _RecvAtHost node producing all the inputs of an
// outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
Status AddRecvAtHostNode(const string& group_attribute,
@ -468,6 +506,14 @@ class Encapsulator {
// The outside_compilation clusters in this subgraph.
std::unordered_map<string, OutsideCompilationSubgraph>
outside_compilation_subgraphs_;
// For each outside_compilation cluster C, the outside_compilation clusters
// that have a path to C outside the compiled graph.
std::unordered_map<string, std::unordered_set<string>>
outside_compilation_ancestors_;
// For each outside_compilation cluster C, the outside_compilation clusters
// that have a path from C outside the compiled graph.
std::unordered_map<string, std::unordered_set<string>>
outside_compilation_successors_;
// NoOp node in the output graph that is sequenced after the call node and
// used to prevent host-side outside_compilation sends and recvs from being
@ -556,6 +602,10 @@ class Encapsulator {
std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
edges_added);
// Adds control dependencies between subgraph call nodes that have
// dependencies via outside_compilation edges.
Status AddCallNodeDependencies(Graph* graph_out);
// Adds all edges to the output graph.
Status AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
@ -620,10 +670,65 @@ class Encapsulator {
const Graph* graph_in_;
std::unordered_map<string, Subgraph> subgraphs_;
// For each subgraph S the subgraphs S' such that there is a path in some
// outside_compilation cluster C in S to some outside_compilation cluster C'
// in S', that goes only through the uncompiled graph.
std::unordered_map<string, std::unordered_set<string>> subgraph_ancestors_;
TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
};
namespace {
// Return in 'sorted' a topological sort of clusters according to the
// dependencies encoded in ancestors. clusters is the list of all clusters
// including clusters that are not present in the ancestors map. has_successors
// is the set of clusters that are ancestors of some other cluster.
void TopologicalClusterSort(
const std::unordered_set<string>& clusters,
const std::unordered_set<string>& has_successors,
const std::unordered_map<string, std::unordered_set<string>>& ancestors,
std::vector<string>* sorted) {
// The nodes are placed in 'sorted' in topological order.
sorted->clear();
// We don't use the standard DFS because we are not operating on Node*
// objects.
struct Work {
string cluster;
bool leave;
};
std::set<string> visited;
std::vector<Work> stack;
// Seed the processing list with clusters that have no successors.
for (const auto& cluster : clusters) {
if (has_successors.find(cluster) == has_successors.end()) {
stack.push_back({cluster, false});
}
}
while (!stack.empty()) {
const Work item = stack.back();
stack.pop_back();
if (item.leave) {
sorted->push_back(item.cluster);
continue;
}
if (visited.find(item.cluster) != visited.end()) continue;
visited.insert(item.cluster);
stack.push_back({item.cluster, true});
const auto& iter = ancestors.find(item.cluster);
if (iter != ancestors.end()) {
for (const auto& ancestor : iter->second) {
stack.push_back({ancestor, false});
}
}
}
CHECK(sorted->size() == clusters.size());
}
} // namespace
Node* Encapsulator::Subgraph::GetCallNodeForInputs() const {
return call_node_inputs_;
}
@ -786,12 +891,71 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
}
}
void Encapsulator::Subgraph::RecordOutsideCompilationDependency(
const string& successor, const string& ancestor) {
outside_compilation_ancestors_[successor].insert(ancestor);
outside_compilation_successors_[ancestor].insert(successor);
}
const std::unordered_map<string, std::unordered_set<string>>
Encapsulator::Subgraph::OutsideCompilationAncestorMap() const {
return outside_compilation_ancestors_;
}
void Encapsulator::Subgraph::GetActiveClusterDependencyGraph(
std::unordered_set<string>* clusters,
std::unordered_set<string>* has_successor,
std::unordered_map<string, std::unordered_set<string>>* ancestors_map) {
// During initial clustering the ancestor and successor datastructures may
// have been built including oc_cluster names that never turned into subgraphs
// because they had no edges into or out of the compiled cluster. Remove them
// before proceeding to simplify the logic. Get the set of clusters that was
// actually added, then remove references to the others.
for (const auto& oc_subgraph : outside_compilation_subgraphs_) {
clusters->insert(oc_subgraph.first);
}
for (const auto& cluster : outside_compilation_successors_) {
if (clusters->find(cluster.first) != clusters->end()) {
for (const auto& successor : cluster.second) {
if (clusters->find(successor) != clusters->end()) {
has_successor->insert(cluster.first);
break;
}
}
}
}
for (const auto& cluster : outside_compilation_ancestors_) {
if (clusters->find(cluster.first) != clusters->end()) {
std::unordered_set<string>& ancestors = (*ancestors_map)[cluster.first];
for (const auto& ancestor : cluster.second) {
if (clusters->find(ancestor) != clusters->end()) {
ancestors.insert(ancestor);
}
}
}
}
}
Status Encapsulator::Subgraph::AddHostComputes(
const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
// Get the set of outside_compilation clusters and the dependency edges
// between them.
std::unordered_set<string> clusters;
std::unordered_set<string> has_successor;
std::unordered_map<string, std::unordered_set<string>> ancestors_map;
GetActiveClusterDependencyGraph(&clusters, &has_successor, &ancestors_map);
// Topologically sort the outside_compilation clusters according to their
// dependency relation.
std::vector<string> sorted_clusters;
TopologicalClusterSort(clusters, has_successor, ancestors_map,
&sorted_clusters);
// The host compute nodes added for each outside_compilation_cluster;
std::unordered_map<string, Node*> host_compute_node;
for (const string& oc_subgraph_name : sorted_clusters) {
OutsideCompilationSubgraph& oc_subgraph =
outside_compilation_subgraphs_[oc_subgraph_name];
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
@ -811,13 +975,22 @@ Status Encapsulator::Subgraph::AddHostComputes(
inputs[input_index].Reset(src_image->name(), src_slot, dtype);
input_dtypes[input_index] = dtype;
}
for (const auto& output : oc_subgraph.outputs_by_src) {
DataType dtype = output.first.dtype;
int output_index = output.second;
output_dtypes[output_index] = dtype;
}
std::vector<string> host_compute_ancestors;
const auto iter = ancestors_map.find(oc_subgraph_name);
if (iter != ancestors_map.end()) {
for (const string& ancestor_cluster : iter->second) {
host_compute_ancestors.push_back(
outside_compilation_subgraphs_[ancestor_cluster]
.host_compute_name);
}
}
NodeDef host_compute_def;
NodeDefBuilder builder(strings::StrCat("outside_compilation_",
oc_subgraph_name, "_host_compute"),
@ -825,6 +998,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
builder.Input(inputs);
builder.Attr("Tinputs", input_dtypes);
builder.Attr("Toutputs", output_dtypes);
builder.Attr("ancestors", host_compute_ancestors);
builder.Attr("key",
strings::StrCat("host_compute_channel_", subgraph_name, "_",
oc_subgraph_name));
@ -834,6 +1008,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
Node* host_compute = graph_->AddNode(host_compute_def, &s);
if (!s.ok()) return s;
host_compute_node[host_compute->name()] = host_compute;
oc_subgraph.host_compute_name = host_compute->name();
// Connect the _HostCompute node to its producers in the subgraph.
@ -852,6 +1027,12 @@ Status Encapsulator::Subgraph::AddHostComputes(
graph_->AddControlEdge(src_image, host_compute);
}
// Connect the _HostCompute node to its ancestor host compute nodes.
for (const auto& ancestor_name : host_compute_ancestors) {
Node* ancestor = host_compute_node[ancestor_name];
graph_->AddControlEdge(ancestor, host_compute);
}
// Connect the consumers in the subgraph to the _HostCompute node.
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
@ -1654,6 +1835,17 @@ Status Encapsulator::CopyEdgeToOutputGraph(
return Status::OK();
}
Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) {
for (const auto& ancestors : subgraph_ancestors_) {
const string& subgraph = ancestors.first;
for (const string& ancestor : ancestors.second) {
graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(),
subgraphs_[subgraph].GetCallNodeForInputs());
}
}
return Status::OK();
}
Status Encapsulator::AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
bool parallel_checking, Graph* graph_out) {
@ -1703,6 +1895,7 @@ Status Encapsulator::AddEdgesToOutputGraph(
Subgraph& subgraph = subgraph_entry.second;
subgraph.ConnectSequencerToCallNode(graph_out);
}
TF_RETURN_IF_ERROR(AddCallNodeDependencies(graph_out));
return Status::OK();
}
@ -1960,6 +2153,182 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
return Status::OK();
}
namespace {
// Helper struct for building cluster dependencies and also debugging cycles in
// the dependencies. While computing dependencies we construct a mapping from
// Node* to PathDetails.
struct PathDetails {
struct SubgraphAndCluster {
string subgraph;
string outside_compilation_cluster;
bool operator==(const SubgraphAndCluster& other) const {
return subgraph == other.subgraph &&
outside_compilation_cluster == other.outside_compilation_cluster;
}
};
struct SubgraphAndClusterHash {
inline std::size_t operator()(const SubgraphAndCluster& v) const {
return hash<string>()(
strings::StrCat(v.subgraph, v.outside_compilation_cluster));
}
};
typedef std::unordered_set<SubgraphAndCluster, SubgraphAndClusterHash>
SubgraphAndClusterSet;
// Returns the set of (subgraph, oc_cluster) pairs that should be recorded as
// ancestors for any successor of this node. If the node is in the outer
// graph, it returns the transitive union of the ancestors of the node's
// inputs. If the node is in an outside_compilation cluster, it returns just
// that cluster. If the node is compiled, it returns the empty set.
SubgraphAndClusterSet AncestorsForSuccessor() {
if (subgraph.empty()) {
return ancestor_clusters;
} else if (outside_compilation_cluster.empty()) {
return SubgraphAndClusterSet();
} else {
SubgraphAndCluster entry;
entry.subgraph = subgraph;
entry.outside_compilation_cluster = outside_compilation_cluster;
return SubgraphAndClusterSet({entry});
}
}
// The transitive union of the ancestor's of this node's inputs. This is only
// saved for debugging in order to print out enough information to debug a
// discovered cycle.
SubgraphAndClusterSet ancestor_clusters;
// The subgraph attr on this node.
string subgraph;
// The outside_compilation attr on this node.
string outside_compilation_cluster;
};
// Adds an edge from ancestor to successor to the cycle detector, and returns an
// error if that edge causes the formation of a cycle. In the error case, logs
// the contents of the node_ancestors_map to facilitate debugging.
Status CheckClusterDependencyForCycles(
const string& ancestor, const string& successor,
const std::unordered_map<string, std::unordered_set<string>>& ancestors,
const std::unordered_map<Node*, PathDetails>& node_ancestors_map,
GraphCycles* cycle_detector, std::map<string, int>* cycle_detector_map) {
if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) {
(*cycle_detector_map)[ancestor] = cycle_detector->NewNode();
}
if (cycle_detector_map->find(successor) == cycle_detector_map->end()) {
(*cycle_detector_map)[successor] = cycle_detector->NewNode();
}
if (!cycle_detector->InsertEdge((*cycle_detector_map)[ancestor],
(*cycle_detector_map)[successor])) {
LOG(ERROR) << "Cycle in outside_compilation clusters";
for (const auto& cluster : ancestors) {
LOG(ERROR) << "Cluster " << cluster.first << " depends on:";
for (const auto& ancestor : cluster.second) {
LOG(ERROR) << " " << ancestor;
}
}
for (const auto& node_ancestors : node_ancestors_map) {
LOG(ERROR) << "Node " << node_ancestors.first->name() << " ("
<< node_ancestors.second.subgraph << ";"
<< node_ancestors.second.outside_compilation_cluster
<< ") has ancestor clusters:";
for (const auto& ancestor : node_ancestors.second.ancestor_clusters) {
LOG(ERROR) << " " << ancestor.subgraph << ";"
<< ancestor.outside_compilation_cluster;
}
}
return errors::InvalidArgument(
"Can't compile outside_compilation clusters because there is a "
"dependency cycle: see error log for details.");
}
return Status::OK();
}
} // namespace
Status Encapsulator::FindClusterDependencies() {
// Map from nodes to ancestor details. A node is entered into the map if it is
// in a compilation subgraph, and outside_compilation cluster, or appears on a
// path in the outer graph leading from an outside_compilation subgraph.
std::unordered_map<Node*, PathDetails> node_ancestors_map;
// We check that clusters are acyclic using this cycle detector.
GraphCycles cycle_detector;
// Map from cluster name to cycle detector node id.
std::map<string, int> cycle_detector_map;
// Process the nodes in topologically-sorted order.
std::vector<Node*> nodes;
GetReversePostOrder(*graph_in_, &nodes);
for (Node* node : nodes) {
string subgraph_name;
string oc_cluster;
TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &subgraph_name, &oc_cluster));
// First create an entry in the ancestors map if the node is in a compiled
// subgraph or outside_compilation cluster, or if any incoming edge is from
// a node with an ancestor map entry; and find the union of all the
// ancestors.
if (!subgraph_name.empty()) {
node_ancestors_map[node].subgraph = subgraph_name;
node_ancestors_map[node].outside_compilation_cluster = oc_cluster;
}
for (Node* src : node->in_nodes()) {
const auto iter = node_ancestors_map.find(src);
if (iter != node_ancestors_map.end()) {
const auto& ancestors_to_follow = iter->second.AncestorsForSuccessor();
for (const auto& ancestor : ancestors_to_follow) {
if (ancestor.subgraph != subgraph_name ||
ancestor.outside_compilation_cluster != oc_cluster) {
node_ancestors_map[node].ancestor_clusters.insert(ancestor);
}
}
}
}
if (!subgraph_name.empty()) {
// The node is in a compiled subgraph or an outside_compilation cluster.
if (oc_cluster.empty()) {
// The node is not in an outside_compilation cluster. Record the
// subgraph's ancestor dependencies.
for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) {
if (cluster.subgraph != subgraph_name) {
subgraph_ancestors_[subgraph_name].insert(cluster.subgraph);
TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles(
cluster.subgraph, subgraph_name, subgraph_ancestors_,
node_ancestors_map, &cycle_detector, &cycle_detector_map));
}
}
} else {
Subgraph& subgraph = subgraphs_[subgraph_name];
// The node is in an outside_compilation cluster. Record the cluster
// and/or subgraph ancestor dependencies.
for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) {
if (cluster.subgraph == subgraph_name) {
// The ancestor is in the same subgraph.
if (cluster.outside_compilation_cluster != oc_cluster) {
// But not in the same oc_cluster, so record the dependency.
subgraph.RecordOutsideCompilationDependency(
oc_cluster, cluster.outside_compilation_cluster);
TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles(
cluster.outside_compilation_cluster, oc_cluster,
subgraph.OutsideCompilationAncestorMap(), node_ancestors_map,
&cycle_detector, &cycle_detector_map));
}
} else {
// The ancestor is in a different subgraph, so record the
// dependency.
subgraph_ancestors_[subgraph_name].insert(cluster.subgraph);
TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles(
cluster.subgraph, subgraph_name, subgraph_ancestors_,
node_ancestors_map, &cycle_detector, &cycle_detector_map));
}
}
}
}
}
return Status::OK();
}
Status Encapsulator::MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
@ -2166,6 +2535,7 @@ Status EncapsulateSubgraphsInFunctions(
Encapsulator encapsulator(std::move(group_attribute),
std::move(outside_compilation_attribute),
&graph_in);
TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies());
TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(

View File

@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions(
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
extern const char* const kXlaCompiledKernelAttr;
// Does `node` have the kXlaCompiledKernelAttr attribute?

View File

@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -74,7 +74,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
if (!compare(elt_a.first, elt_a.second, iter->second)) {
if (diff) {
*diff = strings::StrCat(map_name, " expected: element with key '",
key_to_string(elt_a.first), " has value '",
key_to_string(elt_a.first), "' has value '",
value_to_string(elt_a.second), "' got: '",
value_to_string(iter->second), "'");
}
@ -121,8 +121,22 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
}
return false;
}
std::unordered_set<string> control_input_a;
std::unordered_set<string> control_input_b;
for (int i = 0; i < a.input_size(); ++i) {
if (a.input(i) != b.input(i)) {
if (str_util::StartsWith(a.input(i), "^")) {
if (!str_util::StartsWith(b.input(i), "^")) {
if (diff) {
*diff = strings::StrCat(
diff_preamble, " mismatch for node ", a.name(), " input ", i,
", expected control input ", a.input(i), " got ", b.input(i),
" expected:\n", a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
control_input_a.insert(a.input(i));
control_input_b.insert(b.input(i));
} else if (a.input(i) != b.input(i)) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
" input ", i, ", expected ", a.input(i),
@ -132,11 +146,29 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
return false;
}
}
if (control_input_a != control_input_b) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
" control inputs differ expected:\n",
a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
return EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
return av.DebugString() == bv.DebugString();
if (key == "ancestors") {
// The ancestors are added from a set so the order is unpredictable;
// just compare set equality not list equality.
std::unordered_set<string> a_set(av.list().s().begin(),
av.list().s().end());
std::unordered_set<string> b_set(bv.list().s().begin(),
bv.list().s().end());
return a_set == b_set;
} else {
return av.DebugString() == bv.DebugString();
}
},
strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
diff);
@ -261,6 +293,7 @@ REGISTER_OP("XlaHostCompute")
.Output("outputs: Toutputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("Toutputs: list(type) >= 0")
.Attr("ancestors: list(string) >= 0")
.Attr("key: string")
.Attr("shape_inference_graph: string = ''")
.Attr("shapes: list(shape) >= 0")
@ -899,6 +932,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{"C:o:0", "c:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
@ -1044,17 +1078,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"D:o:0", "F:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors",
gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
{"shapes", gtl::ArraySlice<DataType>({})},
{"_outside_compilation_subgraph", "O2"}},
{"F"}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
@ -1193,6 +1230,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"C:o:0", "D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
@ -1215,6 +1253,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{"G:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph", ""},
{"shapes",
@ -1279,6 +1318,179 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform, each with one outside_compilation
// cluster, with the dependency between them purely from an outside_compilation
// edge.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g =
Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2"));
Node* h = Unary(g, b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1")
.WithControlInput(e));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(f, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT, DT_FLOAT}, shape.opts());
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected));
}
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0"));
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F2", "O1",
{DT_FLOAT}, shape.opts());
Node* h = Unary(recv, shape.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F2", "O1", {h}, shape.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F2_O1", &library_expected));
}
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
{{"f_0_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}},
{{"I"},
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F2_O1"},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"i_0_retval", "I:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = InputShaped(b2.opts().WithName("B"));
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1",
{DT_FLOAT, DT_FLOAT}, b2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
.WithControlInputs({recv1, b})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e},
b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1",
{DT_FLOAT}, b2.opts());
Node* h = Unary(recv2, b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1")
.WithControlInput(e));
Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h},
b2.opts());
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(a).Input(b);
Node* call2 = b2.opts()
.WithControlInputs({s2, call1})
.FinalizeBuilder(&node_builder2);
Binary(call1, call2, b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
@ -1323,6 +1535,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{},
{{"Tinputs", gtl::ArraySlice<DataType>({})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
@ -1406,6 +1619,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{},
{{"Tinputs", gtl::ArraySlice<DataType>({})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
@ -1487,6 +1701,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
@ -1567,6 +1782,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
@ -1607,6 +1823,371 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two outside_compilation clusters that interact outside the compiled
// subgraph, where the ancestor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,
OutsideCompilationClusterDependencyNoSrcCluster) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Node* g = Unary(f, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
Binary(e, h, b1.opts().WithName("I"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT}, shape2.opts());
Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected));
}
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"H"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}}},
},
{{"h_0_retval", "H:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT}, b2.opts());
Node* g = Unary(recv, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(e, call1, b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two outside_compilation clusters that interact outside the compiled
// subgraph, where the successor has no HostCompute Op.
TEST(EncapsulateSubgraphsTest,
OutsideCompilationClusterDependencyNoDstCluster) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
/*Node* g =*/Unary(a, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
Binary(e, h, b1.opts().WithName("I"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, shape1.opts());
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"h_0_retval", "H:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
/*Node* g =*/Unary(a, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(e, call1, b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two outside_compilation clusters that interact outside the compiled
// subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Node* g = Unary(d, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
/*Node* i =*/Binary(d, e,
b1.opts()
.WithName("I")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O3")
.WithControlInput(g));
Binary(e, h, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant =
KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, shape1.opts());
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts());
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected));
}
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {},
{{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"ancestors",
gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"ancestors",
gtl::ArraySlice<string>({"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})},
{"key", "host_compute_channel_F1_O3"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
{{"h_0_retval", "H:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts());
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT}, b2.opts());
Node* g = Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3",
{DT_FLOAT}, b2.opts());
/*Node* i =*/Binary(recv3, e,
b2.opts()
.WithName("I")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O3")
.WithControlInput(g));
Node* s1 = Sequencer(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send, recv2, recv3}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(e, call1, b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
@ -1731,6 +2312,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{"c:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors", gtl::ArraySlice<string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},

View File

@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {
return reachable;
}
bool GraphCycles::CanContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b;
RemoveEdge(a, b);
bool reachable = IsReachableNonConst(a, b);
// Restore the graph to its original state.
InsertEdge(a, b);
// If reachable, then contracting edge will cause cycle.
return !reachable;
}
bool GraphCycles::ContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b));
RemoveEdge(a, b);
@ -388,4 +398,8 @@ std::unordered_set<int32> GraphCycles::Successors(int32 node) {
return rep_->nodes_[node]->out;
}
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
return rep_->nodes_[node]->in;
}
} // namespace tensorflow

View File

@ -85,6 +85,9 @@ class GraphCycles {
// and returns false.
bool ContractEdge(int32 a, int32 b);
// Return true if can contract edge, otherwise return false.
bool CanContractEdge(int32 a, int32 b);
// Return whether dest_node is reachable from source_node
// by following edges.
bool IsReachable(int32 source_node, int32 dest_node) const;
@ -115,6 +118,7 @@ class GraphCycles {
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node);
std::unordered_set<int32> Predecessors(int32 node);
// ----------------------------------------------------
struct Rep;

View File

@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) {
EXPECT_TRUE(g_.HasEdge(1, 4));
}
TEST_F(GraphCyclesTest, CanContractEdge) {
ASSERT_TRUE(AddEdge(1, 2));
ASSERT_TRUE(AddEdge(1, 3));
ASSERT_TRUE(AddEdge(2, 3));
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
EXPECT_FALSE(g_.CanContractEdge(1, 3));
EXPECT_FALSE(g_.CanContractEdge(2, 4));
EXPECT_TRUE(g_.CanContractEdge(1, 2));
EXPECT_TRUE(g_.CanContractEdge(2, 3));
EXPECT_TRUE(g_.CanContractEdge(3, 4));
}
static void BM_StressTest(int iters, int num_nodes) {
while (iters > 0) {
tensorflow::GraphCycles g;

View File

@ -37,30 +37,28 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace gpu = perftools::gputools;
namespace tensorflow {
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: OpKernel(ctx), device_type_(ctx->device_type()) {
const NameAttrList* func;
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
function_ = *func;
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function)
: OpKernel(ctx),
constants_(constants),
resources_(resources),
device_type_(ctx->device_type()),
function_(function) {
if (device_type_ == DeviceType(DEVICE_CPU)) {
platform_id_ = gpu::host::kHostPlatformId;
platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
platform_id_ = gpu::cuda::kCudaPlatformId;
platform_id_ = se::cuda::kCudaPlatformId;
} else {
platform_id_ = nullptr;
}
}
Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
@ -69,9 +67,9 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
return platform.status();
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
@ -92,15 +90,15 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOp::Compute "
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOpBase::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
gpu::Stream* stream =
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
XlaCompilationCache* cache;
@ -114,7 +112,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
const XlaDevice::Metadata* metadata;
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
@ -126,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
std::map<int, OptionalTensor> variables =
SnapshotResourceVariables(ctx, num_resource_args_);
SnapshotResourceVariables(ctx, resources_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
@ -153,27 +151,29 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
options.device_type = &cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
// TODO(b/77671268): We don't set variable_representation_shape_fn here. This
// is restricted to Variables, but we need something like this to apply to
// normal Tensors too.
if (metadata) {
options.shape_representation_fn = metadata->shape_representation_fn();
}
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
std::map<int, Tensor> constant_args;
for (int i = 0; i < num_constant_args_; ++i) {
for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
variables, ctx, &kernel, &executable,
/*compile_options=*/nullptr));
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
&kernel, &executable, &compile_options));
VLOG(1) << "Executing XLA Computation...";
XlaComputationLaunchContext launch_context(
num_resource_args_, client, xla_allocator, allocate_xla_tensors);
XlaComputationLaunchContext launch_context(client, xla_allocator,
allocate_xla_tensors);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
@ -196,14 +196,69 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Done";
}
namespace {
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return RET; \
} \
} while (0)
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
DataTypeVector constant_types;
OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
ctx->GetAttr("Tconstants", &constant_types));
std::vector<int> constants(constant_types.size());
std::iota(constants.begin(), constants.end(), 0);
return constants;
}
std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
DataTypeVector constant_types;
OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
ctx->GetAttr("Tconstants", &constant_types));
DataTypeVector arg_types;
OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
ctx->GetAttr("Targs", &arg_types));
int num_resources;
OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
ctx->GetAttr("Nresources", &num_resources));
std::vector<int> resources(num_resources);
std::iota(resources.begin(), resources.end(),
constant_types.size() + arg_types.size());
return resources;
}
NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
const NameAttrList* func;
OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
return *func;
}
#undef OP_REQUIRES_OK_RETURN
} // namespace
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
FunctionAttr(ctx)) {}
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
XlaLocalLaunchOp);
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
.Device(DEVICE_GPU)
.HostMemory("constants")
.HostMemory("resources"),

View File

@ -26,6 +26,41 @@ limitations under the License.
namespace tensorflow {
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow
// the "constants, then regular args, then resources" order.
// It takes vectors of constant and resource arguments explicitly.
// It does not have corresponding OpDef because it is never present
// in the GraphDef.
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
class XlaLocalLaunchBase : public OpKernel {
public:
XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function);
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
~XlaLocalLaunchBase() override = default;
void Compute(OpKernelContext* ctx) override;
protected:
// Builds a XlaCompilationCache class suitable for the current device.
Status BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache);
// Indexes of compile-time constant inputs
std::vector<int> constants_;
// Indexes of resource inputs
std::vector<int> resources_;
DeviceType device_type_;
NameAttrList function_;
se::Platform::Id platform_id_;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
@ -35,26 +70,12 @@ namespace tensorflow {
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
class XlaLocalLaunchOp : public OpKernel {
class XlaLocalLaunchOp : public XlaLocalLaunchBase {
public:
explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
~XlaLocalLaunchOp() override;
void Compute(OpKernelContext* ctx) override;
private:
// Builds a XlaCompilationCache class suitable for the current device.
Status BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** compiler);
DeviceType device_type_;
NameAttrList function_;
int num_constant_args_;
// Number of resource variable arguments.
int num_resource_args_;
perftools::gputools::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};

View File

@ -17,7 +17,7 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("_XlaLaunch")
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch")
.Attr("Tresults: list(type) >= 0")
.Attr("function: func")
// XLA random-number generation ops are stateful.
// TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
// TODO(phawkins): create stateful and non-stateful variants of XlaLaunch.
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");

View File

@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature(
namespace {
// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
// op.
// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
Status BuildArguments(const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,

View File

@ -48,17 +48,16 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
int64 num_resource_args = variables.size();
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
num_resource_args, client, client->backend().memory_allocator(), true);
client, client->backend().memory_allocator(), true);
launch_context.PopulateInputs(ctx, result, variables);
perftools::gputools::Stream* stream =
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
@ -67,6 +66,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id());
auto run_result = executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile(
options.client = metadata.client();
options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
options.shape_representation_fn = metadata.shape_representation_fn();
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
result, executable,
/*compile_options=*/nullptr);
result, executable, &compile_options);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {

View File

@ -29,11 +29,8 @@ limitations under the License.
namespace tensorflow {
// An OpKernel that compiles an op to an XLA computation and runs it. Unlike
// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
// vanilla TensorFlow op as long as the bridge supports it.
//
// Importantly _XlaLaunch assumes all input and output tensors are on the host,
// whereas XlacompileOnDemandOp works with tensors in device memory.
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

View File

@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations;
std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false, &device));
TF_RETURN_IF_ERROR(
XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}

View File

@ -48,10 +48,9 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace se = ::perftools::gputools;
namespace tensorflow {
// Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
@ -111,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@ -121,7 +122,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
auto platform = se::MultiPlatformManager::PlatformWithName(platform_name);
if (!platform.ok()) {
return StreamExecutorUtil::ConvertStatus(platform.status());
return platform.status();
}
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
@ -130,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
device->reset(new XlaDevice(options, attrs, device_ordinal,
DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal));
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type)
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform) {}
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@ -171,19 +174,28 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceAttributes& attrs, int device_ordinal,
const DeviceType& jit_device_name, se::Platform* platform,
bool transfer_as_literal)
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name),
xla_metadata_(device_ordinal, platform, jit_device_name,
shape_representation_fn),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
transfer_as_literal_(transfer_as_literal) {}
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
}
XlaDevice::~XlaDevice() {}
XlaDevice::~XlaDevice() {
if (gpu_device_info_ != nullptr) {
gpu_device_info_->default_context->Unref();
}
}
xla::LocalClient* XlaDevice::client() const {
// We lazily create the client because the platform commits to the
@ -191,9 +203,8 @@ xla::LocalClient* XlaDevice::client() const {
// don't want to do it until we get a chance to hook the platform up
// to a simulator.
// For now GetOrCreateLocalClient always returns success when passed
// a non-null platform. If that changes we may have to plumb in some
// way to pass Status back.
// TODO(b/78468222): This can fail, at least when the backend is GPU and
// there is no GPU on the host.
return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
}
@ -218,15 +229,33 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator
// is created.
GetAllocator({});
// XlaDevice owns both gpu_device_info_ and
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context = new XlaDeviceContext(
stream, client(), transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
return Status::OK();
}
Status XlaDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) {
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator and
// XlaTensorInfoManager is created.
(void)GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@ -239,11 +268,10 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
// When TraceMe profiling is off (which is the default), the
// following TraceMe constructor is simply a conditional test of
// false value. Measurements show that its overhead is negligible.
port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
// When Xprof profiling is off (which is the default), constructing the
// activity is simple enough that its overhead is negligible.
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
op_kernel->Compute(context);
}
@ -251,8 +279,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
op_kernel->ComputeAsync(context, done);
}
@ -274,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
XlaTransferManager manager(stream, client(), transfer_as_literal_);
XlaTransferManager manager(stream, client(), transfer_as_literal_,
shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;

View File

@ -17,8 +17,7 @@ limitations under the License.
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
// is managed by XLA.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
@ -27,6 +26,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@ -49,20 +49,25 @@ class XlaDevice : public LocalDevice {
// retrieved e.g., when lazily creating the XlaCompilationCache device.
class Metadata {
public:
Metadata(int device_ordinal, perftools::gputools::Platform* platform,
const DeviceType& device_type);
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
// The index of the device on this host.
int device_ordinal() const;
perftools::gputools::Platform* platform() const;
se::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
private:
const int device_ordinal_;
const DeviceType device_type_;
perftools::gputools::Platform* platform_; // Not owned.
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@ -76,17 +81,19 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
static Status Create(const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
std::unique_ptr<XlaDevice>* device);
static Status Create(
const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
::perftools::gputools::Platform* platform,
bool transfer_as_literal);
se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@ -103,7 +110,11 @@ class XlaDevice : public LocalDevice {
Tensor* tensor) override;
xla::LocalClient* client() const;
xla::StatusOr<::perftools::gputools::Stream*> GetStream();
xla::StatusOr<se::Stream*> GetStream();
// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
Status CreateAndSetGpuDeviceInfo();
private:
// The metadata of this XlaDevice.
@ -113,8 +124,8 @@ class XlaDevice : public LocalDevice {
// The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_;
// Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned.
::perftools::gputools::Platform* platform_; // Not owned.
Allocator* xla_allocator_; // Not owned.
se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
@ -123,6 +134,11 @@ class XlaDevice : public LocalDevice {
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref)
// and its stream.
std::unique_ptr<GpuDeviceInfo> gpu_device_info_;
};
// Builds OpKernel registrations on 'device' for the JIT operators

View File

@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/mem.h"
namespace se = ::perftools::gputools;
namespace tensorflow {
// The allocator used for Tensors assigned to the XLA device.
@ -49,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(se::Stream* stream,
xla::LocalClient* client,
bool transfer_as_literal)
XlaTransferManager::XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal) {}
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
@ -78,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice(
transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString();
return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
// Reshape the tensor back to its declared shape.
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
return errors::Internal(
"Tensor::CopyFrom failed when copying from XLA device to CPU");
}
return Status::OK();
}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@ -98,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
TensorShape shape;
if (shape_representation_fn_) {
shape = shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype());
} else {
shape = device_tensor->shape();
}
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), device_tensor->shape(), client_,
device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
@ -108,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
}
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
Status status;
if (transfer_as_literal_) {
status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
@ -173,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(Status::OK());
}
XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
bool transfer_as_literal)
: manager_(stream, client, transfer_as_literal) {}
XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(stream, client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
@ -45,16 +46,16 @@ class XlaDeviceAllocator : public Allocator {
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
explicit XlaTransferManager(perftools::gputools::Stream* stream,
xla::LocalClient* client,
bool transfer_as_literal);
explicit XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
perftools::gputools::Stream* stream() const { return stream_; }
se::Stream* stream() const { return stream_; }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@ -64,13 +65,14 @@ class XlaTransferManager {
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
perftools::gputools::Stream* stream_;
se::Stream* stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
bool transfer_as_literal_;
const bool transfer_as_literal_;
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@ -78,8 +80,9 @@ class XlaTransferManager {
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(perftools::gputools::Stream* stream,
xla::LocalClient* client, bool transfer_as_literal);
explicit XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
@ -87,9 +90,7 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
perftools::gputools::Stream* stream() const override {
return manager_.stream();
}
se::Stream* stream() const override { return manager_.stream(); }
private:
XlaTransferManager manager_;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
@ -32,7 +33,7 @@ namespace tensorflow {
// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
// operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel {
public:
@ -41,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel {
};
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \
REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \
.Device(DEVICE) \
.HostMemory("constants") \
.HostMemory("resources"), \
@ -63,6 +64,9 @@ class XlaDeviceDummyOp : public OpKernel {
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
REGISTER_KERNEL_BUILDER( \
Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \
IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \

View File

@ -48,12 +48,22 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
Status status =
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false, &device);
/*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;
return Status::OK();
}
// TODO(b/78468222): Uncomment after fixing this bug
// status = device->CreateAndSetGpuDeviceInfo();
// if (!status.ok()) {
// errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
// " device");
// return status;
// }
devices->push_back(device.release());
return Status::OK();
}

View File

@ -32,18 +32,19 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/stream_executor_util.h"
namespace gpu = perftools::gputools;
namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
int num_variables) {
std::map<int, OptionalTensor> SnapshotResourceVariables(
OpKernelContext* ctx, const std::vector<int>& variables) {
std::map<int, OptionalTensor> snapshot;
int first_variable = ctx->num_inputs() - num_variables;
for (int i = 0; i < num_variables; ++i) {
for (int i : variables) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
OptionalTensor& tensor = snapshot[first_variable + i];
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
@ -54,74 +55,78 @@ std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
return snapshot;
}
XlaAllocator::XlaAllocator(const gpu::Platform* platform, Allocator* wrapped)
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
: xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size);
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
void* data =
wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted("Out of memory while trying to allocate ",
size, " bytes.");
} else {
return gpu::DeviceMemoryBase(data, size);
}
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
}
Status XlaAllocator::Deallocate(int device_ordinal,
gpu::DeviceMemoryBase* mem) {
wrapped_->DeallocateRaw(mem->opaque());
Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
wrapped_->DeallocateRaw(mem.opaque());
return Status::OK();
}
namespace {
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
std::unique_ptr<xla::ScopedShapedBuffer> ExtractSubShapedBuffer(
xla::ShapedBuffer* shaped_buffer, int index,
ScopedShapedBuffer ExtractSubShapedBuffer(
ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator) {
xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape(
const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_host_shape(), index);
xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape(
const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape(
shaped_buffer->on_device_shape(), index);
xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
shaped_buffer->platform(),
shaped_buffer->device_ordinal());
ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
shaped_buffer->platform(),
shaped_buffer->device_ordinal());
auto& shape_tree = shaped_buffer->buffers();
auto& sub_shape_tree = sub_shaped_buffer.buffers();
sub_shape_tree.CopySubtreeFrom(shape_tree,
/*source_base_index=*/{index},
/*target_base_index=*/{});
for (auto& index_to_buffer : shape_tree) {
if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) {
index_to_buffer.second = gpu::DeviceMemoryBase(nullptr, 0);
}
}
return xla::ScopedShapedBuffer::MakeScoped(&sub_shaped_buffer, allocator)
.ValueOrDie();
shape_tree.ForEachMutableElement(
[index](const xla::ShapeIndex& shape_index,
tensorflow::se::DeviceMemoryBase* data) {
// shape_index is empty for the root node. Ignore that.
if (!shape_index.empty() && shape_index[0] == index) {
*data = tensorflow::se::DeviceMemoryBase(nullptr, 0);
}
});
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
}
} // namespace
} // namespace internal
using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
int64 num_resource_args, xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
: num_resource_args_(num_resource_args),
client_(client),
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors)
: client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
arg_ptrs_ = std::vector<xla::ShapedBuffer*>(arg_buffers_.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
// Pass remaining parameters.
const Tensor* t;
@ -140,16 +145,15 @@ void XlaComputationLaunchContext::PopulateInputs(
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] =
const_cast<xla::ShapedBuffer*>(&xla_tensor->shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
} else {
CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
<< "On-device shape "
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
gpu::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = xla::MakeUnique<ShapedBuffer>(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
@ -160,15 +164,15 @@ void XlaComputationLaunchContext::PopulateInputs(
void XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
std::unique_ptr<xla::ScopedShapedBuffer> output) {
gpu::Stream* stream =
ScopedShapedBuffer output) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
VLOG(2) << "Result tuple shape (on device): "
<< output->on_device_shape().DebugString();
<< output.on_device_shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
@ -191,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs(
OP_REQUIRES_OK(
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
const_tensor.dtype(), const_tensor.shape(),
client_, stream->parent()->device_ordinal()));
}
Device* device = dynamic_cast<Device*>(ctx->device());
OP_REQUIRES(ctx, device != nullptr,
@ -226,18 +225,18 @@ void XlaComputationLaunchContext::PopulateOutputs(
const TensorShape& shape = kernel->outputs[i].shape;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_));
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num});
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
@ -257,7 +256,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
@ -282,12 +281,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_));
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num});
output.set_buffer(xla::OwningDeviceMemory(), {output_num});
*variable->tensor() = output_tensor;
}
++output_num;

View File

@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@ -31,28 +33,28 @@ limitations under the License.
namespace tensorflow {
class XlaAllocator;
// Takes a snapshot of the values of resource variable arguments, which are
// the last `num_variables` arguments. We snapshot tensors that back
// Takes a snapshot of the values of resource variable arguments, whose
// indices are specified in `variables` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// Returns a map of TensorFlow argument index to resource variable.
std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
int num_variables);
// Returns a map of TensorFlow argument index to resource variable. If a
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
OpKernelContext* ctx, const std::vector<int>& variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
XlaAllocator(const perftools::gputools::Platform* platform,
Allocator* wrapped);
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
xla::StatusOr<xla::OwningDeviceMemory> Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) override;
Status Deallocate(int device_ordinal,
perftools::gputools::DeviceMemoryBase* mem) override;
Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
@ -74,7 +76,7 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
@ -87,14 +89,13 @@ class XlaComputationLaunchContext {
// Given the XLA output in `output`, populate all outputs of `ctx`.
void PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
std::unique_ptr<xla::ScopedShapedBuffer> output);
xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
private:
int64 num_resource_args_;
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
@ -126,8 +127,7 @@ class XlaTensorBuffer : public TensorBuffer {
}
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
perftools::gputools::DeviceMemoryBase buffer,
Allocator* allocator) {
se::DeviceMemoryBase buffer, Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
buffer.size(), allocator);
@ -143,6 +143,17 @@ class XlaTensorBuffer : public TensorBuffer {
Allocator* allocator_;
};
// Exposed in this header file for microbenchmarking purposes, but this is an
// internal implementation detail.
namespace internal {
// Return the 'index''th subtree of the given ShapedBuffer as a
// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the
// subtree, and sets the input's buffer pointers to nullptr for the subtree.
xla::ScopedShapedBuffer ExtractSubShapedBuffer(
xla::ShapedBuffer* shaped_buffer, int index,
xla::DeviceMemoryAllocator* allocator);
} // namespace internal
} // namespace tensorflow
#endif

View File

@ -0,0 +1,64 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains microbenchmarks for performance critical functions in
// xla_launch_util.cc.
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs
// (cardinality of each non-leaf node's children).
void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
tensorflow::testing::StopTiming();
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
for (int i = 0; i < depth; ++i) {
std::vector<xla::Shape> shapes(fan_out, shape);
shape = xla::ShapeUtil::MakeTupleShape(shapes);
}
xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr,
/*device_ordinal=*/0);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
// Extract a buffer from approximately the middle of the first level of the
// tree.
(void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
/*index=*/fan_out / 2,
/*allocator=*/nullptr)
.release();
}
}
BENCHMARK(BM_ExtractSubBuffer)
->ArgPair(1, 4)
->ArgPair(1, 8)
->ArgPair(1, 32)
->ArgPair(1, 64)
->ArgPair(1, 128)
->ArgPair(1, 256)
->ArgPair(1, 512)
->ArgPair(2, 4)
->ArgPair(2, 8)
->ArgPair(2, 32)
->ArgPair(2, 64)
->ArgPair(2, 128);
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
tensorflow::testing::RunBenchmarks();
return RUN_ALL_TESTS();
}

View File

@ -31,16 +31,15 @@ namespace tensorflow {
return FromTensor(const_cast<Tensor*>(tensor));
}
/*static*/ perftools::gputools::DeviceMemoryBase
XlaTensor::DeviceMemoryFromTensor(const Tensor& tensor) {
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor);
if (xla_tensor) {
CHECK(xla_tensor->has_shaped_buffer());
return xla_tensor->shaped_buffer().root_buffer();
} else {
return perftools::gputools::DeviceMemoryBase(
const_cast<char*>(tensor.tensor_data().data()),
tensor.tensor_data().size());
return se::DeviceMemoryBase(const_cast<char*>(tensor.tensor_data().data()),
tensor.tensor_data().size());
}
}
@ -53,22 +52,22 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
client->backend().transfer_manager()->HostShapeToDeviceShape(
on_host_shape);
xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(),
device_ordinal);
for (auto& index_to_buffer : buffer.buffers()) {
xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
client->backend().memory_allocator(),
device_ordinal);
for (auto& index_to_buffer : shaped_buffer.buffers()) {
xla::Shape subshape =
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
uint64 size =
client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
TF_ASSIGN_OR_RETURN(index_to_buffer.second,
TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
client->backend().memory_allocator()->Allocate(
device_ordinal, size, /*retry_on_failure=*/false));
// Move our buffer into shaped_buffer, which takes ownership of it.
index_to_buffer.second = buffer.Forget();
}
TF_ASSIGN_OR_RETURN(auto scoped_buffer,
xla::ScopedShapedBuffer::MakeScoped(
&buffer, client->backend().memory_allocator()));
set_shaped_buffer(std::move(scoped_buffer));
set_shaped_buffer(std::move(shaped_buffer));
return Status::OK();
}

View File

@ -43,8 +43,7 @@ class XlaTensor {
// which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is
// {tensor.tensor_data().data(), tensor.tensor_data().size}.
static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor(
const Tensor& tensor);
static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor);
// Assign the internal ShapedBuffer to new memory for the given dtype and
// shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it
@ -55,7 +54,7 @@ class XlaTensor {
// Some Tensors can have complex on-device shapes, including tuple shapes. To
// manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this TensorInfo contains a ShapedBuffer.
// Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
@ -63,17 +62,17 @@ class XlaTensor {
CHECK(has_shaped_buffer());
return *shaped_buffer_;
}
// Mutates the TensorInfo to set the ShapedBuffer.
void set_shaped_buffer(
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer) {
shaped_buffer_ = std::move(shaped_buffer);
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
xla::MakeUnique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
}
// Some tensors on the device may have known values on the host. We use these
// in on-demand mode to avoid re-copying values from the device if we know the
// host value already.
// Return true if this TensorInfo contains a host tensor.
// Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()

View File

@ -42,7 +42,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
"//tensorflow/python:random_seed",
"//tensorflow/python:session",
@ -58,7 +58,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -72,7 +72,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -93,7 +93,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -111,7 +111,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -127,7 +127,7 @@ tf_xla_py_test(
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@ -141,7 +141,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -156,7 +156,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -170,7 +170,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -184,7 +184,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@ -209,7 +209,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@ -225,7 +225,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -241,7 +241,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -263,7 +263,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -291,7 +291,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -300,14 +300,41 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
tags = [
"manual",
"notap",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "eager_test",
size = "small",
srcs = ["eager_test.py"],
disabled_backends = [
# TODO(b/78199195) Support XLA CPU devices in eager runtime
"cpu",
"cpu_ondemand",
# TODO(b/78468222) Enable GPU backend
"gpu",
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:function",
],
)
tf_xla_py_test(
name = "fft_test",
size = "medium",
@ -319,7 +346,7 @@ tf_xla_py_test(
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops",
],
@ -333,19 +360,19 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "ftrl_test",
size = "small",
size = "medium",
srcs = ["ftrl_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -361,7 +388,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -376,12 +403,27 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:image_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "listdiff_op_test",
size = "small",
srcs = ["listdiff_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform_test",
"@six_archive//:six",
],
)
tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
@ -389,7 +431,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -404,7 +446,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -416,7 +458,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -430,7 +472,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -443,7 +485,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -456,7 +498,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -471,7 +513,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -488,7 +530,7 @@ tf_xla_py_test(
],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@ -503,7 +545,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -519,7 +561,7 @@ tf_xla_py_test(
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -532,7 +574,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
],
)
@ -544,7 +586,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -556,7 +598,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -571,7 +613,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -584,7 +626,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:platform_test",
@ -599,7 +641,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -615,7 +657,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -628,7 +670,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/contrib/stateless",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -642,7 +684,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -661,7 +703,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -674,7 +716,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -688,7 +730,7 @@ tf_xla_py_test(
srcs = ["fused_batchnorm_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn",
@ -707,7 +749,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -726,7 +768,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
@ -736,11 +778,12 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
tags = ["noasan"], # times out, http://b/78599043
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -752,7 +795,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -765,21 +808,34 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "xla_device_test",
name = "xla_device_gpu_test",
size = "small",
srcs = ["xla_device_test.py"],
srcs = ["xla_device_gpu_test.py"],
additional_deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
],
)
@ -796,15 +852,23 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
],
# TODO(b/62961789): Test fails with SIGABRT
tags = [
"manual",
"notap",
)
cuda_py_test(
name = "dense_layer_test",
size = "small",
srcs = ["dense_layer_test.py"],
additional_deps = [
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:layers",
"//tensorflow/python:variables",
],
)
@ -847,7 +911,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
@ -862,7 +926,7 @@ cuda_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@ -900,7 +964,19 @@ tf_xla_py_test(
srcs = ["fake_quant_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "placeholder_test",
size = "small",
srcs = ["placeholder_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

View File

@ -29,51 +29,70 @@ from tensorflow.python.platform import test
class ArgMinMaxTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, inp, expected):
"""Verifies that 'op' produces 'expected' when fed input 'inp' .
def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input,
expected):
"""Verifies that 'op' produces 'expected' when fed input 'op_input' .
Args:
op: operator to test
inp: numpy input array to use as input to 'op'.
op: argmin or argmax operator to test.
axis: integer axis to reduce across.
output_type: numpy datatype of the output to produce.
op_input: numpy input array to use as input to 'op'.
expected: numpy array representing the expected output of 'op'.
"""
with self.test_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name="a")
output = op(pinp)
result = session.run(output, {pinp: inp})
dtypes.as_dtype(op_input.dtype), op_input.shape, name="a")
output = op(pinp, axis=axis, output_type=output_type)
result = session.run(output, {pinp: op_input})
self.assertAllEqual(result, expected)
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
minmax_types = set(self.numeric_types) - set(self.complex_types)
for dtype in minmax_types:
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
expected=np.int32(2))
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
expected=np.array([0, 1, 0], dtype=np.int32))
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32),
np.array([[4, 1], [3, 2]], dtype=dtype),
expected=np.array([0, 0], dtype=np.int32))
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
for output_type in self.int_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
output_type=output_type,
op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
expected=output_type(2))
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
output_type=output_type,
op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
expected=np.array([0, 1, 0], dtype=output_type))
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=1,
output_type=output_type,
op_input=np.array([[4, 1], [3, 2]], dtype=dtype),
expected=np.array([0, 0], dtype=output_type))
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32),
np.array([3, 10, 27, 3, 2, 4], dtype=dtype),
expected=np.int32(4))
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32),
np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
expected=np.array([1, 0, 1], dtype=np.int32))
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32),
np.array([[4, 1], [3, 2]], dtype=dtype),
expected=np.array([1, 1], dtype=np.int32))
self._assertOpOutputMatchesExpected(
math_ops.argmin,
axis=0,
output_type=output_type,
op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype),
expected=output_type(4))
self._assertOpOutputMatchesExpected(
math_ops.argmin,
axis=0,
output_type=output_type,
op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype),
expected=np.array([1, 0, 1], dtype=output_type))
self._assertOpOutputMatchesExpected(
math_ops.argmin,
axis=1,
output_type=output_type,
op_input=np.array([[4, 1], [3, 2]], dtype=dtype),
expected=np.array([1, 1], dtype=output_type))
if __name__ == "__main__":

View File

@ -0,0 +1,135 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DenseLayer JIT compilation on the CPU and GPU devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
jit_scope = jit.experimental_jit_scope
def GetRunMetadataLabels(run_metadata):
"""Returns all labels in run_metadata."""
labels = []
for dev_stats in run_metadata.step_stats.dev_stats:
for node_stats in dev_stats.node_stats:
labels.append(node_stats.timeline_label)
return labels
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
return any([substr in x for x in labels])
def XlaLaunchOpCount(labels):
"""Count how many XlaLaunch labels are present."""
return sum("XlaLaunch(" in x for x in labels)
class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
config_pb2.OptimizerOptions.ON_1)
with self.test_session(config=config) as sess:
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
sess.run(
y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
self.assertFalse(InLabels(labels, "ListDiff"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
XlaLaunch op by XLA.
"""
with self.test_session() as sess:
x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
sess.run(
y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
def testDenseLayerJitScopeUndefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
cluster, causing dense layer to be split into TWO XlaLaunch ops.
"""
with self.test_session() as sess:
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
sess.run(
y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, XlaLaunchOpCount(labels))
self.assertFalse(InLabels(labels, "ListDiff"))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,309 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
class EagerTest(XLATestCase):
def testBasic(self):
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, product)
def testExecuteListOutputLen0(self):
with self.test_scope():
empty = constant_op.constant([], dtype=dtypes.float32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
self.assertAllEqual([[0], [3]], result[0])
self.assertAllEqual([[1], [4]], result[1])
self.assertAllEqual([[2], [5]], result[2])
def testBasicGraph(self):
# Run some ops eagerly
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, product)
# Run some ops graphly
with context.graph_mode(), self.test_session() as sess:
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
self.assertAllEqual(15, sess.run(product))
def testDegenerateSlices(self):
with self.test_scope():
npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
t = constant_op.constant(npt)
# degenerate by offering a forward interval with a negative stride
self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
# degenerate with a reverse interval with a positive stride
self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
# empty interval in every dimension
self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])
def testIdentity(self):
with self.test_scope():
self.assertAllEqual(2, array_ops.identity(2))
def testIdentityOnVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(True)
i = array_ops.identity(v)
self.assertAllEqual(True, i.numpy())
def testAssignAddVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
v.assign_add(2.0)
self.assertEqual(3.0, v.numpy())
def testGradient(self):
def f(x):
return x
with self.test_scope():
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
def testVariableGradient(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(1.0)
def f():
x = v0 * v0
return x
grads = backprop.implicit_grad(f)()
self.assertEqual(2., grads[0][0].numpy())
class EagerFunctionTest(XLATestCase):
def testBasic(self):
with self.test_scope():
matmul = function.defun(math_ops.matmul, compiled=True)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t, transpose_a=True)
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
def testConv(self):
if 'GPU' in self.device:
# TODO(b/32333178)
self.skipTest('Current implementation of RandomStandardNormal kernel '
'is very slow on GPU, and has been blacklisted.')
with self.test_scope():
data_format = 'channels_last'
conv = convolutional.Conv2D(
filters=1, kernel_size=2, padding='VALID',
data_format=data_format, activation=nn_ops.relu,
kernel_initializer=init_ops.ones_initializer(),
bias_initializer=init_ops.zeros_initializer())
pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
def model(x):
x = conv(x)
return pool(x)
model = function.defun(model, compiled=True)
x = array_ops.ones([1, 4, 4, 1])
y = model(x)
self.assertAllEqual(y.numpy(), [[[[4.]]]])
def testReadVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
@function.defun(compiled=True)
def f():
return v.read_value()
var = f()
self.assertEqual(1.0, var.numpy())
def testUpdateVariable(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
def f(v):
v.assign_add(1.0)
return v
f = function.defun(f, compiled=True)
var = f(v)
self.assertEqual(2.0, var.numpy())
def testAllArgumentKinds(self):
"""Test a complex function that takes different argument kinds.
tf2xla machinery that translates, compiles, and runs defuns
classifies arguments into: compile-time constants, regular tensors,
and resources. This test creates a function with a mix of all these
kinds. Moreover, the order of function arguments is intentionally mixed up.
This also tests the case when the same argument is a compile-time constant
as well as used in an operation that normally expects its inputs to be
in device memory - addition in this case.
"""
with self.test_scope():
def foo(c1, r1, v1, c2, v2, r2):
# c1 and c2 are compile-time constants
# r1 and r2 are regular tensors
# v1 and v2 are resource variables
a = c1 + r1
b = math_ops.cast(c2, dtypes.float32) + v2
c = array_ops.slice(v1, c1, c2)
d = r2 * v2
return a, b, c, d
foo = function.defun(foo, compiled=True)
c1 = [0, 0]
c2 = array_ops.ones([2], dtype=dtypes.int32)
r1 = array_ops.ones([2])
r2 = [[2., 2.], [3., 3.]]
v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
self.assertAllEqual([1, 1], a.numpy())
self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
self.assertAllEqual([[1.]], c.numpy())
self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
def testDefunInGradientTape(self):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
@function.defun(compiled=True)
def f(x):
x = v0 * v0 * x
return x
x = constant_op.constant(3.0)
with backprop.GradientTape() as tape:
y = f(x)
dy = tape.gradient(y, v0)
self.assertEqual(75, y.numpy())
self.assertEqual(30, dy.numpy())
class ExcessivePaddingTest(XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
Tensors that would normally be excessively padded when written
to TPU memory are reshaped to 1-D flat tensors.
This test case verifies that such tensors work with eager execution.
The flattening currently only happens on TPU, but tests should work
fine with all backends as flattening is transparent.
"""
def testFromConstant(self):
with self.test_scope():
# Create constant of shape [100, 2, 1]. This tensor would be
# excessively padded on TPU.
tensor = constant_op.constant(100 * [[[10.0], [2.0]]])
# Use reduce_sum since it requires correctly working with
# a particular dimension.
reduced = math_ops.reduce_sum(tensor, axis=1)
self.assertAllEqual(100 * [[12.0]], reduced)
def testFromOperation(self):
with self.test_scope():
tensor = array_ops.ones([3, 100, 2, 2])
reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
self.assertAllEqual(100 * [12.0], reduced)
def testAsFunctionInput(self):
with self.test_scope():
@function.defun(compiled=True)
def f(x):
return math_ops.reduce_sum(x, axis=2)
tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
reduced = f(tensor)
self.assertAllEqual(100 * [[12.0]], reduced)
def testAsFunctionOutput(self):
with self.test_scope():
@function.defun(compiled=True)
def f(x):
return x * constant_op.constant(100 * [[[10.0, 2.0]]])
y = f(3)
reduced = math_ops.reduce_sum(y, axis=2)
self.assertAllEqual(100 * [[36.0]], reduced)
if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()

View File

@ -24,12 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@test_util.with_c_api
class FunctionTest(XLATestCase):
def testFunction(self):

View File

@ -34,6 +34,13 @@ from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
def GenerateNumpyRandomRGB(shape):
# Only generate floating points that are fractions like n / 256, since they
# are RGB pixels. Some low-precision floating point types in this test can't
# handle arbitrary precision floating points well.
return np.random.randint(0, 256, shape) / 256.
class RGBToHSVTest(XLATestCase):
def testBatch(self):
@ -43,7 +50,7 @@ class RGBToHSVTest(XLATestCase):
shape = (batch_size, 2, 7, 3)
for nptype in self.float_types:
inp = np.random.rand(*shape).astype(nptype)
inp = GenerateNumpyRandomRGB(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
with self.test_session() as sess:
@ -83,7 +90,7 @@ class RGBToHSVTest(XLATestCase):
def testRGBToHSVNumpy(self):
"""Tests the RGB to HSV conversion matches a reference implementation."""
for nptype in self.float_types:
rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype)
rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype)
rgb_np = rgb_flat.reshape(4, 4, 4, 3)
hsv_np = np.array([
colorsys.rgb_to_hsv(

View File

@ -78,10 +78,10 @@ def InLabels(labels, substr):
def MetadataHasXlaLaunch(run_metadata):
"""Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
"""Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
class JitLaunchTest(test.TestCase):
@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
# If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
# actually ran. However, it is sometimes possible for _XlaLaunch ops to be
# If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
# actually ran. However, it is sometimes possible for XlaLaunch ops to be
# constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
self.assertFalse(InLabels(labels, "_XlaLaunch"))
self.assertFalse(InLabels(labels, "XlaLaunch"))
# Compile the backprop. One _XlaLaunch.
# Compile the backprop. One XlaLaunch.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
self.assertTrue(InLabels(labels, "_XlaLaunch"))
self.assertTrue(InLabels(labels, "XlaLaunch"))
class ElementWiseFusionTest(test.TestCase):
@ -482,14 +482,15 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
count = sum("_XlaLaunch(" in x for x in labels)
count = sum("XlaLaunch(" in x for x in labels)
return output, count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true"
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
"--tf_xla_cpu_global_jit")
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)

View File

@ -0,0 +1,101 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA listdiff operator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
out_tensor, idx_tensor = array_ops.listdiff(
x_tensor, y_tensor, out_idx=index_dtype)
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
self.assertAllEqual(out, tf_out)
self.assertAllEqual(idx, tf_idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3])
def testBasic2(self):
self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3])
def testBasic3(self):
self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2])
def testDuplicates(self):
self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1],
y=[4, 2],
out=[1, 3, 3, 3, 1],
idx=[0, 3, 5, 6, 7])
def testRandom(self):
num_random_tests = 10
int_low = -7
int_high = 8
max_size = 50
for _ in xrange(num_random_tests):
x_size = np.random.randint(max_size + 1)
x = np.random.randint(int_low, int_high, size=x_size)
y_size = np.random.randint(max_size + 1)
y = np.random.randint(int_low, int_high, size=y_size)
out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
if out_idx:
out, idx = map(list, zip(*out_idx))
else:
out = []
idx = []
self._testListDiff(list(x), list(y), out, idx)
def testFullyOverlapping(self):
self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[])
def testNonOverlapping(self):
self._testListDiff(x=[1, 2, 3, 4],
y=[5, 6],
out=[1, 2, 3, 4],
idx=[0, 1, 2, 3])
def testEmptyX(self):
self._testListDiff(x=[], y=[1, 2], out=[], idx=[])
def testEmptyY(self):
self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3])
def testEmptyXY(self):
self._testListDiff(x=[], y=[], out=[], idx=[])
if __name__ == "__main__":
test.main()

View File

@ -22,6 +22,8 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@ -42,20 +44,33 @@ class OutOfMemoryTest(xla_test.XLATestCase):
"""
def test_loop():
size = 2e8
size = int(2e8)
while True:
with self.test_session():
# Force the compiled code to not be constant by feeding in an addend.
p = array_ops.placeholder(dtypes.float32, shape=[])
# Force the compiled code to not be constant by feeding in a
# parameter.
p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
with self.test_scope():
# Create a large R1 tensor.
c = array_ops.zeros([size, 1]) + p
# Create a computation that produces a large R1 tensor as an
# intermediate result. Reduce it down so that if this file was
# compiled without --config=cuda, we don't force a D2H copy of a
# large tensor and potentially OOM the host.
#
# This is a bit tricky because XLA:GPU doesn't currently support RNG
# ops. Here we rely on the fact that XLA doesn't do algebraic
# simplifications on conv(<ones>, <filter>).
c = math_ops.reduce_sum(
nn_ops.convolution(
array_ops.ones([1, size, 1]),
p,
padding='SAME',
data_format='NWC'))
c.eval(feed_dict={p: 1.0})
c.eval(feed_dict={p: [[[1.0]], [[2.0]]]})
size *= 2
self.assertRaises(errors.ResourceExhaustedError, test_loop)
if __name__ == "__main__":
if __name__ == '__main__':
googletest.main()

View File

@ -0,0 +1,48 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for xla handling of placeholder_with_default."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class PlaceholderTest(XLATestCase):
def test_placeholder_with_default_default(self):
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
sess.run(variables.variables_initializer([v]))
self.assertEqual(8.0, sess.run(out))
def test_placeholder_with_default_fed(self):
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
sess.run(variables.variables_initializer([v]))
self.assertEqual(2.0, sess.run(out, {ph: 1.0}))
if __name__ == '__main__':
googletest.main()

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
import itertools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
class ReduceOpPrecisionTest(XLATestCase):
def _testReduceSum(self,
expected_result,
dtype,
test_inputs,
rtol=1e-3,
atol=1e-4):
"""Tests reduce sum on a list of input arrays.
For each array in test_inputs, check that performing reduce sum on the array
produces a value that is close to the expected result.
Args:
expected_result: the expected result.
dtype: the data type of the reduce sum operation.
test_inputs: a list of input arrays for the reduce sum operation.
rtol: the relative error.
atol: the absolute error.
"""
for test_input in test_inputs:
with self.test_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32)
out = math_ops.reduce_sum(a, index)
result = sess.run(out, {
a: np.array(test_input, dtype=dtype),
index: [0]
})
# Compare the results using float32 type.
self.assertAllClose(
np.float32(result),
np.float32(expected_result),
rtol=rtol,
atol=atol)
def testReduceSumF16(self):
"""Tests the reduce sum of float16 doesn't lose too much precision."""
if np.float16 not in self.all_types:
return
f16_max = np.finfo(np.float16).max
self._testReduceSum(
f16_max, np.float16,
itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))
def testReduceSumBF16(self):
"""Tests the reduce sum of bfloat16 doesn't lose too much precision."""
if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
return
bf16_max = np.float32(dtypes.bfloat16.max)
f32_max = dtypes.float32.max
value = min(bf16_max, f32_max - bf16_max)
self._testReduceSum(
dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
if __name__ == '__main__':
googletest.main()

View File

@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase):
# seed were not fixed.
self.assertTrue(self._chi_squared(y, 10) < 16.92)
def testRandomNormalIsFinite(self):
with self.test_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_uniform(
shape=[10000], seed=seed_t, dtype=dtype)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue(np.all(np.isfinite(y)))
def _normal_cdf(self, x):
"""Cumulative distribution function for a standard normal distribution."""
return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))

View File

@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
def testTensorArrayGradientWriteRead(self):
for dtype in self.numeric_types:
for dtype in self.float_types:
self._testTensorArrayGradientWriteReadType(dtype)
for dtype in self.complex_types:
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@ -68,40 +69,41 @@ class TernaryOpsTest(XLATestCase):
expected=np.array([1, 3, 5], dtype=np.int32))
def testSelect(self):
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array(2, dtype=np.float32),
np.array(7, dtype=np.float32),
expected=np.array(7, dtype=np.float32))
for dtype in self.numeric_types:
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array(2, dtype=dtype),
np.array(7, dtype=dtype),
expected=np.array(7, dtype=dtype))
self._testTernary(
array_ops.where,
np.array(1, dtype=np.bool),
np.array([1, 2, 3, 4], dtype=np.float32),
np.array([5, 6, 7, 8], dtype=np.float32),
expected=np.array([1, 2, 3, 4], dtype=np.float32))
self._testTernary(
array_ops.where,
np.array(1, dtype=np.bool),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([1, 2, 3, 4], dtype=dtype))
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32))
self._testTernary(
array_ops.where,
np.array(0, dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
self._testTernary(
array_ops.where,
np.array([0, 1, 1, 0], dtype=np.bool),
np.array([1, 2, 3, 4], dtype=np.float32),
np.array([5, 6, 7, 8], dtype=np.float32),
expected=np.array([5, 2, 3, 8], dtype=np.float32))
self._testTernary(
array_ops.where,
np.array([0, 1, 1, 0], dtype=np.bool),
np.array([1, 2, 3, 4], dtype=dtype),
np.array([5, 6, 7, 8], dtype=dtype),
expected=np.array([5, 2, 3, 8], dtype=dtype))
self._testTernary(
array_ops.where,
np.array([0, 1, 0], dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
self._testTernary(
array_ops.where,
np.array([0, 1, 0], dtype=np.bool),
np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))
def testSlice(self):
for dtype in self.numeric_types:
@ -119,6 +121,23 @@ class TernaryOpsTest(XLATestCase):
np.array([2, 1], dtype=np.int32),
expected=np.array([[2], [5]], dtype=dtype))
def testClipByValue(self):
# TODO(b/78258593): enable integer types here too.
for dtype in self.float_types:
test_cases = [
(np.array([2, 4, 5], dtype=dtype), dtype(7)), #
(dtype(1), np.array([2, 4, 5], dtype=dtype)), #
(np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype))
]
x = np.array([-2, 10, 6], dtype=dtype)
for lower, upper in test_cases:
self._testTernary(
gen_math_ops._clip_by_value,
x,
lower,
upper,
expected=np.minimum(np.maximum(x, lower), upper))
if __name__ == "__main__":
googletest.main()

View File

@ -209,7 +209,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.expm1,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype))
expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype),
rtol=1e-5)
self._assertOpOutputMatchesExpected(
math_ops.floor,
@ -251,12 +252,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
# TODO(b/34703906): improve log1p implementation and make tolerance
# tighter.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)),
rtol=1e-4,
atol=1e-6)
self._assertOpOutputMatchesExpected(
math_ops.rint,
@ -333,13 +334,19 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.elu,
np.array([[-1, 0, 1]], dtype=dtype),
expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
np.array([[-1, 0, 1, -1e-6]], dtype=dtype),
expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype),
rtol=1e-5,
atol=1e-6)
self._assertOpOutputMatchesExpected(
nn_ops.selu,
np.array([[-1, 0, 1]], dtype=dtype),
expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))
np.array([[-1, 0, 1, -1e-5]], dtype=dtype),
expected=np.array(
[[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]],
dtype=dtype),
rtol=1e-5,
atol=1e-6)
self._assertOpOutputMatchesExpected(
nn_ops.relu,
@ -419,7 +426,9 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.expm1,
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)),
rtol=1e-6,
atol=1e-6)
self._assertOpOutputMatchesExpected(
math_ops.reciprocal,
@ -441,13 +450,13 @@ class UnaryOpsTest(XLATestCase):
np.array([[5j, 3 - 2j]], dtype=dtype),
expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
# TODO(b/34703906): improve log1p implementation and make tolerance
# tighter.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
expected=np.log1p(
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)),
rtol=1e-4,
atol=1e-6)
val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
self._assertOpOutputMatchesExpected(
@ -789,7 +798,9 @@ class UnaryOpsTest(XLATestCase):
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features)
self._assertOpOutputMatchesExpected(
nn_ops.softplus, features, expected=expected)
nn_ops.softplus, features, expected=expected,
rtol=1e-6,
atol=9.1e-6)
def testSoftplus(self):
for dtype in self.float_types:

View File

@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for XLA devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaDeviceGpuTest(test.TestCase):
def testCopiesToAndFromGpuWork(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
if __name__ == "__main__":
test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,30 +18,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaDeviceTest(test.TestCase):
class XlaDeviceTest(XLATestCase):
def testCopies(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
"""Tests that copies onto and off XLA devices work."""
shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
[16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types:
for shape in shapes:
with self.test_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape)
with self.test_scope():
y = x + x
with ops.device("CPU"):
z = array_ops.identity(y)
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
inputs = np.random.randint(-100, 100, shape).astype(dtype)
result = sess.run(z, {x: inputs})
self.assertAllCloseAccordingToType(result, inputs + inputs)
if __name__ == "__main__":

View File

@ -81,7 +81,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@ -168,9 +168,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@ -215,7 +215,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@ -326,6 +325,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
@ -412,7 +412,6 @@ cc_library(
hdrs = ["functionalize_control_flow.h"],
deps = [
":tf2xla_util",
"//tensorflow/compiler/jit:graph_to_functiondef",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:xla_ops",

View File

@ -21,13 +21,13 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
return Status::OK();
}
Status FunctionalizeLoop(Graph* graph, Frame* frame,
// Copy the FunctionDef of given function from lookup_library to library, if
// it can be found in lookup_library but is missing from library.
Status AddMissingFunctionByName(const string& function_name,
const FunctionLibraryDefinition* lookup_library,
FunctionLibraryDefinition* library) {
if (!library->Find(function_name) && lookup_library->Find(function_name)) {
return library->AddFunctionDef(*lookup_library->Find(function_name));
}
return Status::OK();
}
// Iterate over all functions that the given fdef refers to. Copy the missing
// FunctionDefs from lookup_library to library.
Status AddMissingFunctionDef(const FunctionDef& fdef,
const FunctionLibraryDefinition* lookup_library,
FunctionLibraryDefinition* library) {
TF_RET_CHECK(lookup_library);
for (const NodeDef& node : fdef.node_def()) {
if (library->Find(node.op())) {
continue;
}
// The function refered by 'SymbolicGradient' node is specified in its
// attribute 'f'.
if (node.op() == FunctionLibraryDefinition::kGradientOp) {
const AttrValue* attr =
AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
if (!attr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
const string& func_name = attr->func().name();
TF_RETURN_IF_ERROR(
AddMissingFunctionByName(func_name, lookup_library, library));
// Copy the user-defined gradient function if it exists.
const string grad_name = lookup_library->FindGradient(func_name);
if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
TF_RETURN_IF_ERROR(
AddMissingFunctionByName(grad_name, lookup_library, library));
GradientDef grad_def;
grad_def.set_function_name(func_name);
grad_def.set_gradient_func(grad_name);
TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
}
} else if (lookup_library->Find(node.op())) {
TF_RETURN_IF_ERROR(
library->AddFunctionDef(*lookup_library->Find(node.op())));
}
}
return Status::OK();
}
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
Graph* graph, Frame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< dump_graph::DumpGraphToFile("functionalize_before", *graph,
@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
if (lookup_library) {
// Copy missing FunctionDefs from lookup_library to library to make library
// self-contained.
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(cond_fdef, lookup_library, library));
TF_RETURN_IF_ERROR(
AddMissingFunctionDef(body_fdef, lookup_library, library));
}
// Builds a While operator.
NodeDef while_def;
@ -870,6 +929,9 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
// Merge the inputs of the switch node with one another. This results in
// predicates and control input residing in the same cluster.
for (const Edge* e : n->in_edges()) {
// Only consider the data inputs to the Switch node.
if (e->IsControlEdge()) continue;
Node* src = e->src();
UnionFind<Cluster>* src_cluster = find_output_cluster(src);
int src_cluster_depth = switch_depth[src_cluster->Get().representative];
@ -1362,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph,
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph,
library);
@ -1431,7 +1499,8 @@ Status FunctionalizeControlFlow(Graph* graph,
continue;
}
TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
TF_RETURN_IF_ERROR(
FunctionalizeLoop(lookup_library, graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;

View File

@ -22,9 +22,13 @@ limitations under the License.
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
// operators, suitable for XLA compilation.
// operators, suitable for XLA compilation. If lookup_library is provided, use
// it to make the library for control flow self-contained.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
} // namespace tensorflow

View File

@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
}
}
// @function.Defun(noinline=True)
// def increment_fn(x):
// return [x + 1]
// Define the above function, and add it to the given graph. It's used as the
// while loop body in NoinlineLoopBody test.
Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
FunctionDef fdef = FunctionDefHelper::Create(
"increment_fn", {"x:int32"}, {"add:int32"}, {},
{
{{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
{{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
},
{{"add", "add_0:z:0"}});
(*fdef.mutable_attr())["_noinline"].set_b(true);
FunctionDefLibrary fdef_lib;
*(fdef_lib.add_function()) = fdef;
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
NodeDef increment_fn;
increment_fn.set_name(node_name);
increment_fn.set_op("increment_fn");
*increment_fn.add_input() = "while/Identity";
*increment_fn.add_input() = "^while/Identity";
Status status;
graph->AddNode(increment_fn, &status);
return status;
}
// Graph:
// x = array_ops.placeholder(dtypes.int32)
// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
const string& noinline_node_name = "while/increment_fn";
Graph graph(OpRegistry::Global());
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
"while/while_context");
auto merge = ops::Merge(scope.WithOpName("while/Merge"),
std::initializer_list<Input>{enter, dummy});
auto ten = ops::Const<int32>(
scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
10);
auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
auto switch_ =
ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
switch_.output_false);
auto identity =
ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
NodeDef next_iter;
next_iter.set_name("while/NextIteration");
next_iter.set_op("NextIteration");
*next_iter.add_input() = noinline_node_name;
(*next_iter.mutable_attr())["T"].set_type(DT_INT32);
Status status;
Node* n = scope.graph()->AddNode(next_iter, &status);
TF_ASSERT_OK(status);
// Remove the dummy node and add the loop backedge.
scope.graph()->RemoveNode(dummy.node());
scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
TF_ASSERT_OK(scope.ToGraph(&graph));
}
FunctionLibraryDefinition lookup_lib(graph.flib_def());
FunctionLibraryDefinition library(OpRegistry::Global(), {});
// Function increment_fn will be copied from lookup_lib to library.
TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
NameAttrList cond_fn, body_fn;
TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
// Outer graph
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
ops::XlaWhile(scope.WithOpName("while/LoopCond"),
std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
// Body graph.
{
Scope scope = Scope::NewRootScope().ExitOnError();
auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
NodeDef retval;
retval.set_name("_retval0_RetVal");
retval.set_op(FunctionLibraryDefinition::kRetOp);
*retval.add_input() = noinline_node_name;
(*retval.mutable_attr())["T"].set_type(DT_INT32);
(*retval.mutable_attr())["index"].set_i(0);
Status status;
scope.graph()->AddNode(retval, &status);
TF_ASSERT_OK(status);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
InstantiationResultForTest result;
// Verify that increment_fn has been copied to library.
TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
// Ignore the function library when comparing the graphs.
expected.clear_library();
TF_EXPECT_GRAPH_EQ(expected, result.gdef);
}
}
// Tests functionalizing OneLoopVar where the loop value is not used post the
// loop.
// Graph:

View File

@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
const std::vector<const XlaExpression*>& expressions,
std::vector<XlaCompiler::Argument>* args) {
auto builder = ctx->builder();
auto client = ctx->compiler()->client();
std::vector<bool> compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.kind = XlaCompiler::Argument::kConstant;
TF_RET_CHECK(expressions[i]->resource() == nullptr)
<< "Input with resource is not yet implemented.";
TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
expressions[i]->handle()));
TF_ASSIGN_OR_RETURN(auto literal,
builder->ComputeConstant(expressions[i]->handle()));
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
} else {
@ -205,14 +208,15 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
TF_RETURN_IF_ERROR(
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false;
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
func, arguments, &result));
TF_RETURN_IF_ERROR(
compiler->CompileFunction(compile_options, func, arguments, &result));
TF_RET_CHECK(arguments.size() == expressions.size());
std::vector<xla::ComputationDataHandle> handles;
std::vector<xla::XlaOp> handles;
for (int64 i = 0; i < expressions.size(); ++i) {
if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
continue;
@ -226,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
auto output_handle = b->Call(*result.computation, handles);
// The output handle of `Call` computation is a tuple type. Unzip it so
// that it can fit into future computations.
int computation_output = 0;
for (int64 i = 0; i < n->num_outputs(); ++i) {
if (result.outputs[i].is_constant) {
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i));
xla_op_context.SetOutput(
i, b->GetTupleElement(output_handle, computation_output));
++computation_output;
}
}
return b->first_error();

View File

@ -21,6 +21,7 @@ tf_kernel_library(
"cast_op.cc",
"categorical_op.cc",
"cholesky_op.cc",
"clip_by_value_op.cc",
"concat_op.cc",
"const_op.cc",
"conv_ops.cc",
@ -44,6 +45,7 @@ tf_kernel_library(
"image_resize_ops.cc",
"index_ops.cc",
"l2loss_op.cc",
"listdiff_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
@ -113,8 +115,8 @@ tf_kernel_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@ -150,7 +152,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -166,7 +168,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -202,8 +204,8 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:argmax_op",

View File

@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel {
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("AddN requires at least one argument"));
xla::ComputationDataHandle sum = ctx->Input(0);
xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = ctx->builder()->Add(sum, ctx->Input(i));
}

View File

@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
xla::ComputationBuilder* builder = ctx->builder();
xla::XlaBuilder* builder = ctx->builder();
xla::ComputationDataHandle input = ctx->Input(0);
xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
int feature_index =
@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel {
input = builder->ConvertElementType(input, scale_type);
if (is_training_) {
xla::ComputationDataHandle output = builder->BatchNormTraining(
xla::XlaOp output = builder->BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx->SetOutput(3, builder->GetTupleElement(output, 1));
ctx->SetOutput(4, builder->GetTupleElement(output, 2));
} else {
xla::ComputationDataHandle output = builder->BatchNormInference(
xla::XlaOp output = builder->BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* const b = ctx->builder();
xla::XlaBuilder* const b = ctx->builder();
DataType input_dtype = ctx->input_type(0);
DataType scale_dtype = ctx->input_type(2);
@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel {
const int feature_index =
GetTensorFeatureDimIndex(input_dims, data_format_);
xla::ComputationDataHandle x_backprop;
xla::ComputationDataHandle scale_backprop;
xla::ComputationDataHandle offset_backprop;
xla::XlaOp x_backprop;
xla::XlaOp scale_backprop;
xla::XlaOp offset_backprop;
if (is_training_) {
xla::ComputationDataHandle output =
xla::XlaOp output =
b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
epsilon_, feature_index);

View File

@ -20,9 +20,8 @@ limitations under the License.
namespace tensorflow {
namespace {
void BatchToSpace(XlaOpKernelContext* ctx,
const xla::ComputationDataHandle& input, DataType input_dtype,
const TensorShape& input_tensor_shape,
void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
DataType input_dtype, const TensorShape& input_tensor_shape,
gtl::ArraySlice<int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
xla::ComputationBuilder* b = ctx->builder();
xla::XlaBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
xla::XlaOp permuted = b->Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
xla::ComputationDataHandle reshaped_permuted =
b->Reshape(permuted, reshaped_permuted_shape);
xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
"Cropped size must be non-negative: start: ", crop_start,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
xla::ComputationDataHandle output =
xla::XlaOp output =
b->Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}

Some files were not shown because too many files have changed in this diff Show More