diff --git a/.gitignore b/.gitignore
index be75938ec40..828bbe9bd33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,6 +27,7 @@ Podfile.lock
/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
+/api_init_files_list.txt
# Android
.gradle
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3dad41a88c8..8669c25c452 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,5 +1,16 @@
# Contributing guidelines
+## Pull Request Checklist
+
+Before sending your pull requests, make sure you have followed this list.
+
+- Read the [contributing guidelines](CONTRIBUTING.md).
+- Read the [Code of Conduct](CODE_OF_CONDUCT.md).
+- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
+- Check that your changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
+- Ensure your changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style).
+- Run the [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests) (see the sketch below).
+
## How to become a contributor and submit your own code
### Contributor License Agreements
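The unit-test item in the checklist above is normally satisfied with Bazel, which this repository builds with. A minimal, hedged sketch; the target pattern is illustrative, and the linked Running Unit Tests section is authoritative:

```shell
# Run all tests under the Python package tree (narrow the pattern for speed).
bazel test //tensorflow/python/...
```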
diff --git a/README.md b/README.md
index e1a50c87e26..6fb4486d0de 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,9 @@
-----------------
-| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
-|-----------------|---------------------|------------------|-------------------|---------------|---------------|
-| [](https://www.tensorflow.org/api_docs/) |  |  |  | [](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [](https://ci.tensorflow.org/job/tensorflow-master-android) [  ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
+| **`Documentation`** |
+|-----------------|
+| [](https://www.tensorflow.org/api_docs/) |
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. The graph nodes represent mathematical operations, while
@@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU
packages on Linux, Mac, and Windows.
-**Individual whl files**
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
-* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
-([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
-
#### *Try your first TensorFlow program*
```shell
$ python
@@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in
[](https://bestpractices.coreinfrastructure.org/projects/1486)
+
+## Continuous build status
+
+### Official Builds
+
+| Build Type | Status | Artifacts |
+| --- | --- | --- |
+| **Linux CPU** |  | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Linux GPU** |  | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Linux XLA** | TBA | TBA |
+| **MacOS** |  | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows CPU** | [](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows GPU** | [](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Android** | [](https://ci.tensorflow.org/job/tensorflow-master-android) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) |
+
+
+### Community Supported Builds
+
+| Build Type | Status | Artifacts |
+| --- | --- | --- |
+| **IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
+| **IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
+
+
## For more information
* [TensorFlow Website](https://www.tensorflow.org)
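As a hedged illustration of the nightly artifacts linked from the build-status tables above (package names taken from the pypi links; assumes a working pip, ideally inside a virtualenv):

```shell
# CPU-only nightly build; substitute tf-nightly-gpu for the GPU build.
pip install tf-nightly
```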
diff --git a/RELEASE.md b/RELEASE.md
index e8459531748..84d9d52868e 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,62 @@
+# Release 1.8.0
+
+## Major Features And Improvements
+* Can now pass `tf.contrib.distribute.MirroredStrategy()` to `tf.estimator.RunConfig()` to run an Estimator model on multiple GPUs on one machine.
+* Add `tf.contrib.data.prefetch_to_device()`, which supports prefetching to GPU memory.
+* Added Gradient Boosted Trees as pre-made Estimators: `BoostedTreesClassifier`, `BoostedTreesRegressor`.
+* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability.
+* `tf.contrib.bayesflow` is moving out to its own repo.
+* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication<sup>[1](#rpc-issue)</sup>.
+
+## Bug Fixes and Other Changes
+* `tf.data`:
+ * Add `tf.contrib.data.prefetch_to_device`, which enables prefetching dataset elements to GPU memory.
+ * Add `tf.contrib.data.AUTOTUNE`, which allows the tf.data runtime to automatically tune the prefetch buffer sizes based on your system and environment.
+ * Add `tf.contrib.data.make_csv_dataset` for building datasets of CSV files.
+* Eager Execution:
+ * With eager execution enabled, Datasets can now be used as standard Python iterators (`for batch in dataset:`); a short sketch follows these release notes. Both `Dataset.__iter__()` and `Dataset.make_one_shot_iterator()` can be used to create iterators when eager execution is enabled.
+ * Automatic device placement has been enabled: a GPU is used automatically when available, without requiring an explicit `with tf.device("/gpu:0")`. (Fixes #14133.)
+ * `tf.GradientTape` has moved out of contrib.
+* `tf.keras`:
+ * Added the Fashion-MNIST dataset.
+ * New data preprocessing functions: `image/random_brightness`, `sequence/TimeseriesGenerator`, and `text/hashing_trick`.
+* Accelerated Linear Algebra (XLA):
+ * Select and scatter in reference util and evaluator now use lexicographical order to break ties.
+* TensorFlow Debugger (tfdbg) CLI:
+ * During tensor-filter operations, allow exclusion of nodes by regular expressions.
+ * Fix spurious background colors in some text terminals.
+* `tf.contrib`:
+ * Add meta-distribution BatchReshape which reshapes batch dimensions.
+ * `tf.contrib.layers.recompute_grad` works for explicit gradient checkpointing on TPU.
+ * Add `tf.contrib.framework.argsort`.
+ * Allow `DNNBoostedTreeCombinedEstimator` to work with core versions of feature columns and losses.
+ * Add non-linear image warping ops: `tf.contrib.image.sparse_image_warp`, `tf.contrib.image.dense_image_warp`, and `tf.contrib.image.interpolate_spline`.
+ * Fix bug in `tf.contrib.opt.MultitaskOptimizerWrapper` where types of tensors were mismatched.
+* Other:
+ * Low-level graph construction now calls the TensorFlow C API. This change should be invisible to most users, but can be disabled by setting the environment variable `TF_C_API_GRAPH_CONSTRUCTION=0` in this release. Future releases will remove the ability to disable this change. Please [file a bug](https://github.com/tensorflow/tensorflow/issues/new) if you find yourself using this escape hatch.
+ * Add description of shapes and a pointer to tutorial notebook in `tf.distributions.Distribution`.
+ * Update scatter operations:
+ * Add `tf.scatter_min` and `tf.scatter_max`
+ * Extend scatter operations to work with a scalar update parameter.
+ * Move cuDNN RNN ops to core for use in TensorFlow codebase only.
+ * Add `float64` support for `Conv2D`, `Conv2DBackpropInput`, and `Conv2DBackpropFilter`.
+ * Add `float64` support for `AvgPool`/`AvgPoolGrad`.
+ * Make graph name scopes thread-local so that they work correctly in multi-threaded environments.
+ * Update nsync synchronization library to avoid slow primitives on Linux.
+ * Removed need to put nsync/public on C include path when building custom ops.
+ * Add `tf.image.psnr`, `tf.image.ssim`, `tf.image.ssim_multiscale`, `tf.image.image_gradients`, `tf.image.sobel_edges`.
+ * Add links to https://js.tensorflow.org.
+ * Fix non-uniformity of orthogonal matrices.
+ * Fix bug where multi-image Estimator eval summaries were not displayed correctly.
+
+<a name="rpc-issue"></a>1. The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu
+
# Release 1.7.0
## Major Features And Improvements
@@ -177,7 +236,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田
* Add `complex64` support to XLA compiler.
* `bfloat` support is now added to XLA infrastructure.
* Make `ClusterSpec` propagation work with XLA devices.
- * Use a determinisitic executor to generate XLA graph.
+ * Use a deterministic executor to generate XLA graph.
* `tf.contrib`:
* `tf.contrib.distributions`:
* Add `tf.contrib.distributions.Autoregressive`.
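A minimal sketch of the eager-mode `tf.data` iteration described in the 1.8.0 notes above, assuming TensorFlow 1.8 with eager execution enabled; the dataset contents are illustrative:

```python
import tensorflow as tf

tf.enable_eager_execution()

# With eager execution enabled, a Dataset is a standard Python iterable.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).batch(2)
for batch in dataset:
    print(batch.numpy())  # e.g. [1 2], then [3 4]
```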
diff --git a/SECURITY.md b/SECURITY.md
index a5ce3a62ee2..01886b613e5 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -173,7 +173,7 @@ the progress being made towards a fix and announcement.
In addition, please include the following information along with your report:
* Your name and affiliation (if any).
-* A description the technical details of the vulnerabilities. It is very
+* A description of the technical details of the vulnerabilities. It is very
important to let us know how we can reproduce your findings.
* An explanation of who can exploit this vulnerability, and what they gain when
doing so -- write an attack scenario. This will help us evaluate your report
diff --git a/WORKSPACE b/WORKSPACE
index 11c5cdb2070..4ddfb9a3832 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -2,11 +2,11 @@ workspace(name = "org_tensorflow")
http_archive(
name = "io_bazel_rules_closure",
- sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657",
- strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f",
+ sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
+ strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz",
- "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz",
+ "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13
],
)
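When an `http_archive` is bumped as above, the new `sha256` is typically recomputed from the pinned tarball; a hedged one-liner (assumes curl and coreutils are available):

```shell
curl -L https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz | sha256sum
```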
diff --git a/configure.py b/configure.py
index 8fb89791116..b6c32543cf7 100644
--- a/configure.py
+++ b/configure.py
@@ -226,8 +226,6 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
- write_to_bazelrc('build --force_python=py%s' % python_major_version)
- write_to_bazelrc('build --host_force_python=py%s' % python_major_version)
write_to_bazelrc('build --python_path=\"%s\"' % python_bin_path)
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
@@ -500,10 +498,6 @@ def set_cc_opt_flags(environ_cp):
if not is_ppc64le() and not is_windows():
write_to_bazelrc('build:opt --host_copt=-march=native')
write_to_bazelrc('build:opt --define with_default_optimizations=true')
- # TODO(mikecase): Remove these default defines once we are able to get
- # TF Lite targets building without them.
- write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
- write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK')
def set_tf_cuda_clang(environ_cp):
"""set TF_CUDA_CLANG action_env.
@@ -847,8 +841,8 @@ def reformat_version_sequence(version_str, sequence_count):
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
- 'Please specify the CUDA SDK version you want to use, '
- 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
+ 'Please specify the CUDA SDK version you want to use. '
+ '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
@@ -1228,6 +1222,9 @@ def set_tf_cuda_compute_capabilities(environ_cp):
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
# Check whether all capabilities from the input are valid
all_valid = True
+ # Strip all whitespace that users may have inserted by accident before
+ # splitting the string, since stray whitespace would cause a validation error
+ tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m:
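A small self-contained sketch of the whitespace normalization added to `set_tf_cuda_compute_capabilities` above; the interactive prompt loop is omitted and the helper name is hypothetical, but the regex mirrors the one in configure.py:

```python
import re

def normalize_compute_capabilities(raw):
    # Strip whitespace users may have inserted by accident, then validate
    # each comma-separated capability, as configure.py does.
    caps = ''.join(raw.split())
    for cap in caps.split(','):
        if not re.match('[0-9]+.[0-9]+', cap):
            raise ValueError('Invalid compute capability: %s' % cap)
    return caps

print(normalize_compute_capabilities('3.5, 5.2 ,7.0'))  # -> 3.5,5.2,7.0
```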
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 18eeb281680..b86b277ac32 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
- tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
+ tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 9678ee926fc..95b04f9058a 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -184,6 +184,7 @@ library {
return std::move(functions[0]);
}
+#if not defined(PLATFORM_WINDOWS)
// On success, returns a set of TF_Function instances encoding a dataset
// node stack that reads an Imagenet TFRecordFile dataset from `file_path`, and
// sets `dataset_name` to the created dataset name. The returned functions must
@@ -7076,7 +7077,9 @@ library {
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
#endif
}
+#endif
+#if not defined(PLATFORM_WINDOWS)
// On success, returns a set of TF_Function instances encoding a dataset
// node stack that reads an MNIST file dataset from `file_path`, and
// sets `dataset_name` to the created dataset name. The returned functions must
@@ -8221,6 +8224,7 @@ library {
return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status);
#endif
}
+#endif
// Adds the input functions to `graph`. On success, returns the created
// IteratorGetNext node.
@@ -8314,6 +8318,13 @@ TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph,
TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
TF_Graph* graph, const char* file_path, int batch_size,
unsigned char is_mnist, TF_Status* status) {
+#if defined(PLATFORM_WINDOWS)
+ // TODO(ashankar): get these functions working on Windows.
+ status->status = tensorflow::errors::Unimplemented(
+ "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API "
+ "is not implemented for Windows");
+ return nullptr;
+#else
tensorflow::Status s;
std::string dataset_name;
@@ -8355,4 +8366,92 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
<< graph->graph.ToGraphDefDebug().DebugString();
return getnext_node;
+#endif
+}
+
+TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ VLOG(1) << "Dequeuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ }
+
+ TF_Operation* dequeue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
+ if (dequeue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the dequeue node in the TF graph.");
+ return nullptr;
+ }
+
+ VLOG(1) << "Running the dequeue op";
+ TF_Output output{dequeue_op, 0};
+ TF_Tensor* ret;
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
+ // output related parameters
+ /*outputs*/ &output, /*output_values*/ &ret,
+ /*noutputs*/ 1,
+ /*targets*/ nullptr, /*ntargets*/ 0,
+ /*run_metadata*/ nullptr, status);
+ if (VLOG_IS_ON(1) && status->status.ok()) {
+ tensorflow::Tensor tensor;
+ if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
+ VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
+ }
+ }
+ return ret;
+}
+
+void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TF_Tensor* tensor, TF_Status* status) {
+ assert(session);
+ {
+ tensorflow::mutex_lock c(session->graph->mu);
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << "Enqueuing named tensor with id " << tensor_id
+ << ", with input graph: "
+ << session->graph->graph.ToGraphDefDebug().DebugString();
+ tensorflow::Tensor internal_tensor;
+ if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
+ VLOG(1) << "Enqueu'ing tensor content: "
+ << internal_tensor.DebugString();
+ }
+ }
+ }
+
+ TF_Operation* enqueue_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
+ if (enqueue_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the enqueue node in the TF graph.");
+ return;
+ }
+
+ TF_Operation* placeholder_op = TF_GraphOperationByName(
+ session->graph,
+ tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
+ if (placeholder_op == nullptr) {
+ status->status = tensorflow::errors::Internal(
+ "Unable to find the placeholder node as input to enqueue in the TF "
+ "graph.");
+ return;
+ }
+
+ VLOG(1) << "Running the enqueue op";
+ TF_Output input{placeholder_op, 0};
+ TF_SessionRun(session, /*run_options*/ nullptr,
+ // input related parameters
+ /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
+ // output related parameters
+ /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
+ /*targets*/ &enqueue_op, /*ntargets*/ 1,
+ /*run_metadata*/ nullptr, status);
+ VLOG(1) << "Enqueuing is done.";
}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 88cb173cd25..20bdace40f1 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -86,6 +86,35 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
TF_Graph* graph, const char* file_path, int batch_size,
unsigned char is_mnist, TF_Status* status);
+// On success, dequeues a tensor from a TF-managed FifoQueue given by
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_dequeue_", to be executed by this API call.
+
+// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
+// empty, this call is blocked.
+//
+// Tensors are enqueued via the corresponding TF enqueue op.
+// TODO(hongm): Add support for `timeout_ms`.
+TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TF_Status* status);
+
+// On success, enqueues `tensor` into a TF-managed FifoQueue given by
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_enqueue_", to be executed by this API call. It reads
+// from a placeholder node "arg_tensor_enqueue_".
+//
+// `tensor` is still owned by the caller. This call will be blocked if the queue
+// has reached its capacity, and will be unblocked when the queued tensors again
+// drop below the capacity due to dequeuing.
+//
+// Tensors are dequeued via the corresponding TF dequeue op.
+// TODO(hongm): Add support for `timeout_ms`.
+TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TF_Tensor* tensor,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
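A hedged usage sketch for the two declarations above, in C. It assumes a valid `TF_Session` whose graph already contains the `fifo_queue_enqueue_0`, `fifo_queue_dequeue_0`, and `arg_tensor_enqueue_0` nodes for `tensor_id` 0; error handling is abbreviated:

```c
#include <stdio.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

/* `session` and `tensor` are assumed to be created and owned elsewhere. */
void roundtrip_named_tensor(TF_Session* session, TF_Tensor* tensor) {
  TF_Status* status = TF_NewStatus();

  /* Blocks if the queue is at capacity; `tensor` stays owned by the caller. */
  TF_EnqueueNamedTensor(session, /*tensor_id=*/0, tensor, status);
  if (TF_GetCode(status) != TF_OK)
    fprintf(stderr, "enqueue failed: %s\n", TF_Message(status));

  /* Blocks if the queue is empty; the caller owns the returned tensor. */
  TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
  if (TF_GetCode(status) == TF_OK && out != NULL) TF_DeleteTensor(out);

  TF_DeleteStatus(status);
}
```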
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index ca80db23ed3..577f10c5e69 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) {
}
const tensorflow::string input_op_name =
- tensorflow::ParseTensorName(input_name).first.ToString();
+ std::string(tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
@@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
const tensorflow::string output_op_name =
- tensorflow::ParseTensorName(output_name).first.ToString();
+ std::string(tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
@@ -1700,7 +1700,7 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
-// REGISTER_OP for CApiTestAttributesTest test cases.
+// REGISTER_OP for CApiAttributesTest test cases.
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index cd19cf8d624..c16aba666ee 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index b1f7bdaa542..74bc25a491a 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
- v2_reader_->key().ToString() /* full var's name */,
+ std::string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
- if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
+ if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
- string key = v2_reader_->key().ToString();
+ string key = std::string(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
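The `ToString()` to `std::string(...)` rewrites in this file (and in c_api.cc and c_api_test.cc above) follow the migration of `tensorflow::StringPiece` toward `string_view` semantics, where no `ToString()` member exists and an owned copy requires an explicit construction. A minimal C++17 sketch, using `std::string_view` as an assumed stand-in for `StringPiece`:

```cpp
#include <iostream>
#include <string>
#include <string_view>

int main() {
  // Stand-in for tensorflow::StringPiece.
  std::string_view piece = "tensor_name";

  // No ToString() on string_view; make an owned copy explicitly,
  // which is the pattern this diff adopts.
  std::string owned(piece);

  std::cout << owned << "\n";
  return 0;
}
```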
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index a2d96357ac8..9ce781fab0b 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -24,14 +24,13 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
- ":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
+ "//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
- "//tensorflow/core/common_runtime/eager:execute_node",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
@@ -49,6 +48,18 @@ tf_cuda_library(
],
"//conditions:default": [],
}) + [
+ "//tensorflow/core/common_runtime/eager:eager_operation",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
+ "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:remote_device",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core:gpu_runtime",
],
)
@@ -59,7 +70,6 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
- ":runtime",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@@ -69,10 +79,23 @@ tf_cuda_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/distributed_runtime:remote_device",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
+ "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)
@@ -91,47 +114,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- ],
-)
-
-tf_cuda_library(
- name = "runtime",
- srcs = ["runtime.cc"],
- hdrs = ["runtime.h"],
- copts = tf_copts(),
- visibility = ["//tensorflow:internal"],
- deps = select({
- "//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite",
- ],
- "//conditions:default": [
- "//tensorflow/c:c_api",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core/common_runtime/eager:kernel_and_device",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
- }),
-)
-
-tf_cc_test(
- name = "runtime_test",
- srcs = ["runtime_test.cc"],
- deps = [
- ":runtime",
- "//tensorflow/cc:cc_ops",
- "//tensorflow/cc:client_session",
- "//tensorflow/cc:ops",
- "//tensorflow/cc:scope",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index c96a38dec3e..216210c88c1 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
-#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@@ -32,16 +31,22 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
-#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -72,6 +77,121 @@ string DeviceName(const tensorflow::Device* d) {
std::atomic_int_fast64_t func_id_generator(0);
#endif // TENSORFLOW_EAGER_USE_XLA
+tensorflow::Status GetAllRemoteDevices(
+ const std::vector<string>& remote_workers,
+ tensorflow::WorkerCacheInterface* worker_cache,
+ std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
+ std::vector<tensorflow::Device*> remote_devices;
+ tensorflow::Status status;
+ // TODO(nareshmodi) do this in parallel instead of serially.
+ for (const string& remote_worker : remote_workers) {
+ tensorflow::Notification n;
+ tensorflow::NewRemoteDevices(
+ tensorflow::Env::Default(), worker_cache, remote_worker,
+ [&status, &n, &remote_devices](
+ const tensorflow::Status& s,
+ std::vector<tensorflow::Device*>* devices) {
+ status = s;
+ if (s.ok()) {
+ for (tensorflow::Device* d : *devices) {
+ remote_devices.push_back(d);
+ }
+ }
+ n.Notify();
+ });
+ n.WaitForNotification();
+ }
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
+ new tensorflow::DeviceMgr(remote_devices));
+
+ TF_RETURN_IF_ERROR(status);
+
+ *device_mgr = std::move(remote_device_mgr);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status CreateRemoteContexts(
+ const std::vector<string>& remote_workers,
+ tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+ tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
+ for (int i = 0; i < remote_workers.size(); i++) {
+ const string& remote_worker = remote_workers[i];
+
+ tensorflow::eager::CreateContextRequest request;
+ tensorflow::eager::CreateContextResponse response;
+ tensorflow::DeviceNameUtils::ParsedName parsed_name;
+ if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
+ &parsed_name)) {
+ return tensorflow::errors::InvalidArgument(
+ "Unable to parse ", remote_worker, " as a device name");
+ }
+ request.mutable_server_def()->set_job_name(parsed_name.job);
+ request.mutable_server_def()->set_task_index(parsed_name.task);
+ request.set_async(async);
+ auto* eager_client = remote_eager_workers->GetClient(remote_worker);
+ if (eager_client == nullptr) {
+ return tensorflow::errors::Internal(
+ "Cannot find a client for the given target:", remote_worker);
+ }
+ tensorflow::Notification n;
+ tensorflow::Status status;
+ // TODO(nareshmodi) do this in parallel instead of serially.
+ eager_client->CreateContextAsync(
+ &request, &response, [&status, &n](const tensorflow::Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ TF_RETURN_IF_ERROR(status);
+
+ remote_contexts->emplace(remote_worker, response.context_id());
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
+ TFE_Context** ctx) {
+ string worker_name = tensorflow::strings::StrCat(
+ "/job:", opts->server_def.job_name(),
+ "/replica:0/task:", opts->server_def.task_index());
+ std::unique_ptr<tensorflow::eager::EagerGrpcServer> server;
+ TF_RETURN_IF_ERROR(
+ tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server));
+
+ TF_RETURN_IF_ERROR(server->Start());
+
+ std::vector<string> remote_workers;
+ server->master_env()->worker_cache->ListWorkers(&remote_workers);
+ remote_workers.erase(
+ std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
+ remote_workers.end());
+
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
+ TF_RETURN_IF_ERROR(GetAllRemoteDevices(
+ remote_workers, server->master_env()->worker_cache, &remote_device_mgr));
+
+ std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
+ server->channel_cache();
+ std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
+ tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
+
+ // Initialize remote eager workers.
+ tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
+ TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
+ remote_eager_workers.get(),
+ opts->async, &remote_contexts));
+
+ tensorflow::RemoteRendezvous* r =
+ server->worker_env()->rendezvous_mgr->Find(0);
+
+ auto* device_mgr = server->worker_env()->device_mgr;
+ *ctx = new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, r, std::move(server),
+ std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts);
+
+ return tensorflow::Status::OK();
+}
} // namespace
extern "C" {
@@ -92,6 +212,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
+ TFE_ContextOptions* options, const void* proto, size_t proto_len,
+ TF_Status* status) {
+ if (!options->server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ }
+}
+
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -101,28 +230,35 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
+ if (!opts->server_def.job_name().empty()) {
+ TFE_Context* ctx = nullptr;
+ status->status = NewRemoteAwareTFE_Context(opts, &ctx);
+ return ctx;
+ }
+
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
- if (!status->status.ok()) {
- return nullptr;
- }
+ if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
+
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
+
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, std::move(device_mgr), r);
}
-void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
- delete ctx;
-}
+void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
- ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
+ ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
+ if (ctx->context.remote_device_mgr()) {
+ ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
+ }
return list;
}
@@ -220,9 +356,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
return retval;
}
-} // extern "C"
-
-extern "C" {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
@@ -242,21 +375,18 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
- tensorflow::Device* d = nullptr;
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status = op->ctx->context.FindDeviceByName(device_name, &d);
- }
- op->device = d;
+ status->status = op->operation.SetDevice(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
- tensorflow::Device* device =
- (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
+ tensorflow::Device* device = (op->operation.Device() == nullptr)
+ ? op->operation.EagerContext()->HostCPU()
+ : op->operation.Device();
return device->name().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
- op->use_xla = enable;
+ op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
@@ -264,22 +394,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
- h->handle->Ref();
- op->inputs.push_back(h->handle);
- op->attrs.NumInputs(op->inputs.size());
+ op->operation.AddInput(h->handle);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret;
- if (op->is_function()) {
+ if (op->operation.is_function()) {
status->status = tensorflow::errors::Unimplemented(
"TODO(apassos): Support for attributes for TensorFlow functions is not "
"ready yet.");
return TF_ATTR_INT; // The compiler requires that we return something.
}
- status->status =
- tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
+ status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
+ attr_name, &ret, is_list);
return ret;
}
@@ -298,23 +426,24 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
}
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
- op->attrs.Set(attr_name, value);
+ op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
- op->attrs.Set(attr_name, static_cast<int64>(value));
+ op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
- op->attrs.Set(attr_name, value);
+ op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
- op->attrs.Set(attr_name, (value == 0) ? false : true);
+ op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
- op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
+ op->operation.MutableAttrs()->Set(attr_name,
+ static_cast<tensorflow::DataType>(value));
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
@@ -336,23 +465,24 @@ void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
proto.add_dim()->set_size(dims[d]);
}
}
- op->attrs.Set(attr_name, proto);
+ op->operation.MutableAttrs()->Set(attr_name, proto);
}
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
- func->set_name(value->name);
- value->attrs.FillAttrValueMap(func->mutable_attr());
- op->attrs.Set(attr_name, attr_value);
+ func->set_name(value->operation.Name());
+ value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
+ op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
#define TFE_OP_SET_ATTR_LIST(fn, type) \
void fn(TFE_Op* op, const char* attr_name, const type* values, \
int num_values) { \
- op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
- values, num_values)); \
+ op->operation.MutableAttrs()->Set( \
+ attr_name, \
+ tensorflow::gtl::ArraySlice<const type>(values, num_values)); \
}
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
@@ -360,14 +490,14 @@ TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const int64>(
- reinterpret_cast<const int64*>(values), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const int64>(
+ reinterpret_cast<const int64*>(values), num_values));
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
- op->attrs.Set(
+ op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
@@ -379,8 +509,8 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
@@ -410,9 +540,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
}
}
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
- proto.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
+ proto.get(), num_values));
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
@@ -420,534 +550,25 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
- funcs[i].set_name(value[i]->name);
- value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
+ funcs[i].set_name(value[i]->operation.Name());
+ value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
- op->attrs.Set(attr_name,
- tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
- funcs.get(), num_values));
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
+ funcs.get(), num_values));
}
-} // extern "C"
-
-namespace {
-
-// Initializes the step stats if needed.
-void MaybeInitializeStepStats(tensorflow::StepStats* step_stats,
- tensorflow::EagerContext* ctx) {
- // Lazily initialize the RunMetadata with information about all devices if
- // this is the first call.
- while (step_stats->dev_stats_size() < ctx->devices()->size()) {
- int device_idx = step_stats->dev_stats_size();
- auto* dev_stats = step_stats->add_dev_stats();
- dev_stats->set_device(ctx->devices()->at(device_idx)->name());
- }
-}
-
-int StepStatsDeviceIndex(tensorflow::StepStats* step_stats,
- tensorflow::EagerContext* ctx,
- tensorflow::Device* device) {
- // Find the current device's index.
- if (device == nullptr) {
- device = ctx->HostCPU();
- }
- for (int i = 0; i < ctx->devices()->size(); ++i) {
- if (ctx->devices()->at(i) == device ||
- ctx->devices()->at(i)->name() == device->name()) {
- return i;
- }
- }
- // TODO(apassos) do not fall back to host CPU if device is unknown.
- return 0;
-}
-
-tensorflow::Status ValidateInputTypeAndPlacement(
- tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op,
- const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) {
- tensorflow::Device* host_device = ctx->HostCPU();
- const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
- if (memtypes.size() != op->inputs.size()) {
- return tensorflow::errors::InvalidArgument(
- "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
- }
- for (int i = 0; i < op->inputs.size(); ++i) {
- const tensorflow::Device* expected_device =
- memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
- tensorflow::TensorHandle* handle = op->inputs[i];
- tensorflow::Device* handle_device = nullptr;
- TF_RETURN_IF_ERROR(handle->Device(&handle_device));
- const tensorflow::Device* actual_device =
- handle_device == nullptr ? host_device : handle_device;
- if (expected_device != actual_device) {
- switch (ctx->GetDevicePlacementPolicy()) {
- case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32:
- // TODO(xpan): See if we could bubble python related error up
- // to python level.
- if (handle->dtype == tensorflow::DT_INT32) {
- // Note: enabling silent copies of int32 tensors to match behavior
- // of graph mode.
- break;
- }
- TF_FALLTHROUGH_INTENDED;
- case tensorflow::DEVICE_PLACEMENT_EXPLICIT:
- return tensorflow::errors::InvalidArgument(
- "Tensors on conflicting devices:"
- " cannot compute ",
- op->name, " as input #", i, " was expected to be on ",
- expected_device->name(), " but is actually on ",
- actual_device->name(), " (operation running on ",
- op_device->name(), ")",
- " Tensors can be copied explicitly using .gpu() or .cpu() "
- "methods,"
- " or transparently copied by using tf.enable_eager_execution("
- "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
- "between devices"
- " may slow down your model");
- case tensorflow::DEVICE_PLACEMENT_WARN:
- LOG(WARNING) << "before computing " << op->name << " input #" << i
- << " was expected to be on " << expected_device->name()
- << " but is actually on " << actual_device->name()
- << " (operation running on " << op_device->name()
- << "). This triggers a copy which can be a performance "
- "bottleneck.";
- break;
- case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing.
- break;
- }
- // We are only here if the policy is warn or silent copies, so we should
- // trigger a copy.
- auto pre_time = tensorflow::Env::Default()->NowMicros();
- tensorflow::TensorHandle* copied_tensor = nullptr;
- tensorflow::Status status = tensorflow::EagerCopyToDevice(
- handle, ctx, expected_device->name().c_str(), &copied_tensor);
- if (run_metadata != nullptr) {
- auto* step_stats = run_metadata->mutable_step_stats();
- MaybeInitializeStepStats(step_stats, ctx);
- // Record the sending on the source device for now.
- int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
- auto* node_stats = dev_stats->add_node_stats();
- node_stats->set_node_name("_Send");
- node_stats->set_all_start_micros(pre_time);
- node_stats->set_op_end_rel_micros(
- tensorflow::Env::Default()->NowMicros() - pre_time);
- }
- if (!status.ok()) {
- if (copied_tensor != nullptr) copied_tensor->Unref();
- return tensorflow::errors::Internal(
- "Failed copying input tensor from ", actual_device->name(), " to ",
- expected_device->name(), " in order to run ", op->name, ": ",
- status.error_message());
- }
- handle->Unref();
- handle = copied_tensor;
- op->inputs[i] = copied_tensor;
- }
- if (handle->dtype != kernel->input_type(i)) {
- return tensorflow::errors::InvalidArgument(
- "cannot compute ", op->name, " as input #", i,
- " was expected to be a ",
- tensorflow::DataTypeString(kernel->input_type(i)),
- " tensor but is a ", tensorflow::DataTypeString(handle->dtype),
- " tensor");
- }
- }
- return tensorflow::Status::OK();
-}
-
-tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
- TFE_Context* ctx, TF_Status* status) {
- tensorflow::DeviceSet ds;
- for (tensorflow::Device* d : *ctx->context.devices()) {
- ds.AddDevice(d);
- }
- tensorflow::DeviceTypeVector final_devices;
- status->status = tensorflow::SupportedDeviceTypesForNode(
- ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
- if (!status->status.ok()) {
- return nullptr;
- }
- if (final_devices.empty()) {
- status->status = tensorflow::errors::Internal(
- "Could not find valid device for node ", ndef.DebugString());
- return nullptr;
- }
- for (tensorflow::Device* d : *ctx->context.devices()) {
- if (d->device_type() == final_devices[0].type_string()) {
- return d;
- }
- }
- status->status = tensorflow::errors::Unknown(
- "Could not find a device for node ", ndef.DebugString());
- return nullptr;
-}
-
-
-#ifdef TENSORFLOW_EAGER_USE_XLA
-// Synthesizes and returns a wrapper function over `op`, which must be a
-// primitive op (e.g. matmul).
-//
-// The wrapper function conforms to the function signature expected by
-// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
-// resources>. For example, if the op has input params <Const1, Arg2, Const3,
-// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
-// Resource4> as the input params to the synthesized function.
-//
-// It populates `const_input_types`, `arg_input_types` and
-// `op_input_to_func_input` based on the reordering results, that the caller can
-// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
-// `status` accordingly.
-const tensorflow::FunctionDef* OpToFunction(
- TFE_Op* op, std::vector<TF_DataType>* const_input_types,
- std::vector<TF_DataType>* arg_input_types,
- tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
- TF_Status* status) {
- DCHECK(!op->is_function());
-
- tensorflow::FunctionDef fdef;
-
- // Get the OpDef of the op we are trying to encapsulate.
- TFE_Context* ctx = op->ctx;
- const tensorflow::OpRegistrationData* op_data;
- {
- status->status = ctx->context.FindFunctionOpData(op->name, &op_data);
- if (!status->status.ok()) {
- return nullptr;
- }
- }
- const tensorflow::OpDef& op_def = op_data->op_def;
-
- tensorflow::OpDef* signature = fdef.mutable_signature();
-
- // Handle constant inputs.
- const std::unordered_set<string> const_inputs(
- *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));
-
- // First add place holders for the input args, so that we can refer to them by
- // position in the next loop. Also tally up the resource inputs.
- int num_resource_inputs = 0;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
- ++num_resource_inputs;
- }
- signature->add_input_arg();
- }
-
- // Now we map the input params from `op_def` to `signature`, where the param
-// ordering for `signature` is: <constants, args, resources>.
- int const_index = 0;
- int arg_index = const_inputs.size();
- int resource_index = op_def.input_arg_size() - num_resource_inputs;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
- tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
- if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
- VLOG(1) << "For const input, mapping op input " << i << " to func input "
- << const_index;
- (*op_input_to_func_input)[i] = const_index;
- func_input_arg = signature->mutable_input_arg(const_index++);
- const_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i]->dtype));
- } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
- VLOG(1) << "For resource input, mapping op input " << i
- << " to func input " << resource_index;
- (*op_input_to_func_input)[i] = resource_index;
- func_input_arg = signature->mutable_input_arg(resource_index++);
- } else {
- VLOG(1) << "For arg input, mapping op input " << i << " to func input "
- << arg_index;
- (*op_input_to_func_input)[i] = arg_index;
- func_input_arg = signature->mutable_input_arg(arg_index++);
- arg_input_types->push_back(
- static_cast<TF_DataType>(op->inputs[i]->dtype));
- }
-
- func_input_arg->set_name(op_input_arg.name());
- func_input_arg->set_type(op->inputs[i]->dtype);
- }
- VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
-
- // Resource args are at the end of the function input params, and we should
- // have iterated over all of them.
- DCHECK_EQ(signature->input_arg_size(), resource_index);
-
- // Make the synthesized function's name unique.
- signature->set_name(tensorflow::strings::StrCat(
- op_def.name(), func_id_generator.fetch_add(1)));
-
- // Add the node def and set its input names to match op_def's names.
- const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
- DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
- *fdef.add_node_def() = ndef;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
- }
- VLOG(1) << "Added NodeDef: " << fdef.DebugString();
-
- // Fix the output names and set output types.
- for (int i = 0; i < op_def.output_arg_size(); ++i) {
- tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
- const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
- const string& out_tensor_name = tensorflow::strings::StrCat(
- ndef.name(), ":", op_def_arg.name(), ":", 0);
- arg->set_name(op_def_arg.name());
- (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
- const string& type_attr = op_def_arg.type_attr();
- if (!type_attr.empty()) {
- auto i = ndef.attr().find(type_attr);
- if (i == ndef.attr().end()) {
- status->status = tensorflow::errors::InvalidArgument(
- tensorflow::strings::StrCat("Could not find attr ", type_attr,
- " in NodeDef ", ndef.DebugString()));
- return nullptr;
- }
- arg->set_type(i->second.type());
- }
- }
- VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
-
- status->status = ctx->context.AddFunctionDef(fdef);
- if (!status->status.ok()) return nullptr;
- const auto ret = ctx->context.FindFunctionDef(signature->name());
- DCHECK(ret != nullptr);
- return ret;
-}
-
-// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
-// via XLA.
-std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
- VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
- auto launch_op =
- std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
- if (TF_GetCode(status) != TF_OK) return nullptr;
- if (op->device) {
- TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- }
-
- const tensorflow::FunctionDef* fdef;
- {
- fdef = op->ctx->context.FindFunctionDef(op->name);
- }
- std::vector<TF_DataType> const_input_types;
- std::vector<TF_DataType> arg_input_types;
- tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
- if (fdef == nullptr) {
- // See if this is a primitive op, and if so create a function for it, so
- // that _XlaLaunchOp can access it.
- fdef = OpToFunction(op, &const_input_types, &arg_input_types,
- &op_input_to_func_input, status);
- if (!status->status.ok()) return nullptr;
- } else {
- // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
- // functions, so we need to find another way to handle constant inputs.
- for (int i = const_input_types.size();
- i < fdef->signature().input_arg_size(); ++i) {
- VLOG(1) << "Adding Targs from input arg " << i;
- const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
- arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
- }
- }
- DCHECK(fdef != nullptr);
-
- // Copy inputs and their devices.
- // Since input param reordering may have occurred between `op` and `launch_op`
- // via `op_input_to_func_input`, adjust the actual inputs accordingly.
- launch_op->inputs = op->inputs;
- for (tensorflow::TensorHandle* h : launch_op->inputs) {
- h->Ref();
- }
- if (!op_input_to_func_input.empty()) {
- DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
- for (int i = 0; i < op_input_to_func_input.size(); ++i) {
- VLOG(1) << "mapping op input " << i << " to func input "
- << op_input_to_func_input[i];
-
- launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
- }
- }
- launch_op->attrs.NumInputs(op->inputs.size());
-
- TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
- const_input_types.size());
-
- // Set Targs and Nresources attrs.
- TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
- arg_input_types.size());
- const int num_resource_inputs = fdef->signature().input_arg_size() -
- const_input_types.size() -
- arg_input_types.size();
- TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
-
- // Set Tresults attr.
- std::vector<TF_DataType> tresults;
- for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
- tresults.push_back(static_cast<TF_DataType>(arg.type()));
- }
- TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
- tresults.size());
-
- // Set function attr.
- tensorflow::AttrValue attr_value;
- tensorflow::NameAttrList* func = attr_value.mutable_func();
- func->set_name(fdef->signature().name());
- launch_op->attrs.Set("function", attr_value);
-
- return launch_op;
-}
-#endif // TENSORFLOW_EAGER_USE_XLA
-
-} // namespace
-
-extern "C" {
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
- TFE_Context* ctx = op->ctx;
- status->status = ctx->context.GetStatus();
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
+ *num_retvals);
+ status->status =
+ tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
if (!status->status.ok()) {
return;
}
-#ifdef TENSORFLOW_EAGER_USE_XLA
- std::unique_ptr xla_launch_op;
- if (op->use_xla && op->name != "_XlaLaunch") {
- xla_launch_op = BuildXlaLaunch(op, status);
- if (!status->status.ok()) {
- return;
- }
- op = xla_launch_op.get();
- }
-#endif // TENSORFLOW_EAGER_USE_XLA
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
- for (int i = 0; i < op->inputs.size(); ++i) {
- tensorflow::Device* input_op_device = nullptr;
- status->status = op->inputs[i]->OpDevice(&input_op_device);
- if (!status->status.ok()) return;
- VLOG(2) << "for op " << op->name << " input " << i << " "
- << tensorflow::DataTypeString(op->inputs[i]->dtype) << " "
- << (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->device == nullptr ? "cpu" : op->device->name());
- if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
- (input_op_device != op->device || input_op_device == nullptr)) {
- tensorflow::Device* d =
- input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->name << " to "
- << d->name() << " because input #" << i
- << " is a resource in this device.";
- op->device = d;
- }
- }
- tensorflow::Device* device = op->device;
-
- tensorflow::Fprint128 cache_key =
- op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
- tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
- if (kernel == nullptr) {
- const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
- if (device == nullptr) {
- device = SelectDevice(ndef, ctx, status);
- if (!status->status.ok()) {
- return;
- }
- }
- CHECK(device != nullptr);
- if (ctx->context.LogDevicePlacement()) {
- LOG(INFO) << "Executing op " << ndef.op() << " in device "
- << device->name();
- }
- kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
- // Knowledge of the implementation of Init (and in-turn
- // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
- // will be accessed, so grab on to the lock.
- // See WARNING comment in Execute (before kernel->Run) - would be nice to
- // rework to avoid this subtlety.
- tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
- status->status = tensorflow::KernelAndDevice::Init(
- ndef, ctx->context.func_lib(device), kernel);
- if (!status->status.ok()) {
- delete kernel;
- return;
- }
- // Update output_dtypes inside `kernel`.
- const tensorflow::OpDef* op_def = nullptr;
- const tensorflow::FunctionDef* function_def =
- ctx->context.FuncLibDef()->Find(ndef.op());
- if (function_def != nullptr) {
- op_def = &(function_def->signature());
- }
- if (op_def == nullptr) {
- status->status = OpDefForOp(ndef.op().c_str(), &op_def);
- if (!status->status.ok()) {
- return;
- }
- }
- tensorflow::DataTypeVector input_dtypes;
- status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
- kernel->mutable_output_dtypes());
- if (!status->status.ok()) {
- return;
- }
- ctx->context.AddKernelToCache(cache_key, kernel);
- }
- const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
- const int output_dtypes_size = output_dtypes.size();
- if (output_dtypes_size > *num_retvals) {
- TF_SetStatus(status, TF_INVALID_ARGUMENT,
- tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
- " outputs, but *num_retvals is ",
- *num_retvals)
- .c_str());
- return;
- }
- *num_retvals = output_dtypes_size;
- if (device == nullptr) {
- // TODO(apassos) debug how the assignment below might return a different
- // device from the one requested above.
- device = kernel->device();
- }
- status->status = ValidateInputTypeAndPlacement(
- &ctx->context, device, op, kernel->kernel(),
- ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto()
- : nullptr);
- if (!status->status.ok()) return;
- std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
- if (ctx->context.ShouldStoreMetadata()) {
- maybe_stats.reset(new tensorflow::NodeExecStats);
- maybe_stats->set_node_name(op->name);
- maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
- maybe_stats->set_op_start_rel_micros(0);
- maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
- // TODO(apassos) track referenced tensors
- }
- if (ctx->context.Async()) {
- // Note that for async mode, execution order will make sure that all
- // input handles are ready before executing them.
- // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
- *num_retvals);
- tensorflow::uint64 id = op->ctx->context.NextId();
- for (int i = 0; i < *num_retvals; ++i) {
- tensorflow::TensorHandle* h =
- new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
- retvals[i] = new TFE_TensorHandle(h);
- handle_retvals[i] = h;
- }
- tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
- id, &op->ctx->context, op->device, op->inputs, kernel,
- maybe_stats.release(), output_dtypes, handle_retvals);
- ctx->context.ExecutorAdd(node);
- } else {
- // Execute checks if retvals[i] is nullptr or not to figure if it needs to
- // allocate it.
- std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
- nullptr);
- status->status = tensorflow::EagerExecute(
- &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
- handle_retvals.data(), *num_retvals);
- for (int i = 0; i < *num_retvals; ++i) {
- retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
- }
+ for (int i = 0; i < *num_retvals; ++i) {
+ retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
}
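
A usage sketch for the slimmed-down TFE_Execute (illustrative, not part of the diff; `op` and `status` are assumed to come from TFE_NewOp/TF_NewStatus as in the tests below). Note that `*num_retvals` is in/out: callers pass the capacity of `retvals` and read back the number of outputs actually produced:

    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;  // in: capacity of retvals; out: outputs produced
    TFE_Execute(op, &retvals[0], &num_retvals, status);
    if (TF_GetCode(status) != TF_OK) {
      // TF_Message(status) carries the error string.
    }
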
@@ -1090,10 +711,3 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
-
-
-TFE_Op::~TFE_Op() {
- for (tensorflow::TensorHandle* h : inputs) {
- h->Unref();
- }
-}
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index c06ce84a8c5..574a097e0d6 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// worker's name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
+ TFE_ContextOptions* options, const void* proto, size_t proto_len,
+ TF_Status* status);
+
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
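
An illustrative call sequence for the new option (mirroring the TestRemoteExecute test added to c_api_test.cc below, where GetServerDef is a test helper):

    tensorflow::ServerDef server_def = GetServerDef(2);
    const std::string serialized = server_def.SerializeAsString();
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
                                   status);
    TFE_Context* ctx = TFE_NewContext(opts, status);
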
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 05dc64f5217..2b8384d7203 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -28,14 +28,23 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -45,12 +54,12 @@ limitations under the License.
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
-
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
+ tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -64,6 +73,23 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
+ explicit TFE_Context(
+ const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ tensorflow::DeviceMgr* local_device_mgr,
+ tensorflow::Rendezvous* rendezvous,
+ std::unique_ptr<tensorflow::GrpcServer> server,
+ std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
+ const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
+ remote_contexts)
+ : context(opts,
+ static_cast<tensorflow::ContextDevicePlacementPolicy>(
+ default_policy),
+ async, local_device_mgr, rendezvous, std::move(server),
+ std::move(remote_eager_workers), std::move(remote_device_mgr),
+ remote_contexts) {}
+
tensorflow::EagerContext context;
};
@@ -85,19 +111,9 @@ struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
- : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
+ : operation(&ctx->context, op, t) {}
- ~TFE_Op();
-
- bool const is_function() const { return attr_types == nullptr; }
-
- TFE_Context* ctx; // Must outlive the TFE_Op.
- const tensorflow::string name;
- tensorflow::AttrBuilder attrs;
- const tensorflow::AttrTypeMap* attr_types;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs;
- tensorflow::Device* device;
- bool use_xla = false;
+ tensorflow::EagerOperation operation;
};
namespace tensorflow {
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 701175e4943..49646bb7359 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -23,7 +24,9 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::string;
@@ -220,6 +223,103 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ tensorflow::ServerDef server_def;
+ server_def.set_protocol("grpc");
+ server_def.set_job_name("localhost");
+ server_def.set_task_index(0);
+ tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+ tensorflow::JobDef* job_def = cluster_def->add_job();
+ job_def->set_name("localhost");
+ for (int i = 0; i < num_tasks; i++) {
+ int port = tensorflow::testing::PickUnusedPortOrDie();
+ job_def->mutable_tasks()->insert(
+ {i, tensorflow::strings::StrCat("localhost:", port)});
+ }
+ return server_def;
+}
+
+void TestRemoteExecute(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
+ ASSERT_TRUE(
+ tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
+ status);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+ TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ auto* h0_task1 =
+ TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ auto* h1_task1 =
+ TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 = TFE_TensorHandleCopyToDevice(
+ retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteTensorHandle(retval_task0);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(h1_task0);
+ TFE_DeleteTensorHandle(h0_task1);
+ TFE_DeleteTensorHandle(h1_task1);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
+TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 97c323b8722..dcc2357b71a 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,13 +130,15 @@ class GradientTape {
}
}
- bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+ bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes);
void Watch(int64 tensor_id);
void RecordOperation(const string& op_type,
gtl::ArraySlice<TapeTensor> output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter);
@@ -170,12 +172,32 @@ class GradientTape {
// Template instantiations here
+inline bool IsDtypeTrainable(DataType dtype) {
+ switch (dtype) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ case DT_COMPLEX64:
+ case DT_COMPLEX128:
+ case DT_RESOURCE:
+ case DT_VARIANT:
+ return true;
+ default:
+ return false;
+ }
+}
+
template <typename Gradient, typename BackwardFunction>
bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
- gtl::ArraySlice<int64> tensor_ids) {
- for (int64 i : tensor_ids) {
- if (tensor_tape_.find(i) != tensor_tape_.end()) {
- return true;
+ gtl::ArraySlice<int64> tensor_ids,
+ gtl::ArraySlice<tensorflow::DataType> dtypes) {
+ CHECK_EQ(tensor_ids.size(), dtypes.size());
+ for (int i = 0; i < tensor_ids.size(); ++i) {
+ if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+ if (IsDtypeTrainable(dtypes[i])) {
+ return true;
+ }
}
}
return false;
@@ -189,9 +211,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
template <typename Gradient, typename BackwardFunction>
void GradientTape<Gradient, BackwardFunction>::RecordOperation(
const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+ gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ BackwardFunction* backward_function,
const std::function<void()>& backward_function_deleter) {
- if (!ShouldRecord(input_tensor_id)) {
+ if (!ShouldRecord(input_tensor_id, input_dtypes)) {
backward_function_deleter();
return;
}
@@ -380,49 +404,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
- const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "failed to find operation producing a tensor");
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "failed to find operation producing a tensor");
+ }
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
}
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "none of operations outputs match expected tensor");
- }
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "none of operations outputs match expected tensor");
}
} else {
- (*result)[id].push_back(output_gradients[i]);
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
}
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@@ -451,8 +465,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
+ tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions
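
The net effect of threading dtypes through the tape: a watched tensor is now recorded only when its dtype can carry gradients. A minimal sketch, assuming the tape's bool-persistent constructor and placeholder template arguments:

    GradientTape<Gradient, BackwardFunction> tape(/*persistent=*/false);
    tape.Watch(/*tensor_id=*/1);
    tape.ShouldRecord({1}, {DT_FLOAT});  // true: watched, trainable dtype
    tape.ShouldRecord({1}, {DT_INT32});  // false: watched, but not trainable
    tape.ShouldRecord({2}, {DT_FLOAT});  // false: not watched
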
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 93155998b86..e18fdf6c57b 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status) {
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ if (!handle_data.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Couldn't deserialize HandleData proto");
+ return;
+ }
+ DCHECK(handle_data.is_set());
+
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(&output.oper->node);
+
+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+ for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+ tensorflow::shape_inference::ShapeHandle shape;
+ status->status =
+ ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+ if (!status->status.ok()) return;
+ shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+ }
+ ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
} // namespace tensorflow
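
A hypothetical round trip over the renamed getter and the new setter (`graph`, `src`, `dst`, and `status` are assumed to exist), copying resource-handle shape/type information from one output to another:

    std::string proto = tensorflow::GetResourceHandleShapeAndType(graph, src);
    tensorflow::SetResourceHandleShapeAndType(graph, dst, proto.data(),
                                              proto.size(), status);
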
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 2d4c8cd9ed7..4bcb5bde62c 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -55,9 +55,15 @@ void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
// `output` if it's a resource tensor, or otherwise returns the empty string.
-// TODO(b/74620627): remove when _USE_C_SHAPES is removed
-std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// Sets `output` based on `proto`, which should be a serialized
+// CppShapeInferenceResult::HandleData proto.
+// NOTE(skyewm): `proto` is passed as a void*/size_t pair instead of a std::string
+// because I couldn't get SWIG to work otherwise.
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d73121c7b70..d6a4f141b6b 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
- return name.ToString();
+ return std::string(name);
}
void InferArgAttributes(const OpDef::ArgDef& arg,
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index c143b978338..62a889181e7 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -220,7 +220,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
- current_constraints.insert(s.ToString());
+ current_constraints.insert(std::string(s));
}
}
} else {
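
This change and the cc_op_gen.cc one above replace StringPiece::ToString() with explicit std::string construction, presumably to keep the code building as StringPiece migrates toward absl::string_view (which has no ToString()). A minimal illustration:

    tensorflow::StringPiece sp("loc:@v");
    std::string s(sp);  // compiles for both the legacy class and string_view
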
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 6545e4ee3eb..ff348fadb24 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
+Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
+ const std::vector