Merge branch 'master' into master

This commit is contained in:
XinPing Wang 2019-05-17 16:41:51 +08:00 committed by GitHub
commit a37ddd9c78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5495 changed files with 417946 additions and 190748 deletions

View File

@ -105,9 +105,6 @@ build --define=PREFIX=/usr
build --define=LIBDIR=$(PREFIX)/lib
build --define=INCLUDEDIR=$(PREFIX)/include
# Disable MKL-DNN contraction kernels by default.
build --define=tensorflow_mkldnn_contraction_kernel=0
# Default options should come above this line
# Options from ./configure

View File

@ -18,10 +18,11 @@ about: Use this template for reporting a bug or a performance issue.
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**

View File

@ -1,17 +1,55 @@
---
name: Documentation Issue
about: Use this template for documentation related issues
about: Use this template for documentation related
labels: 'type:docs'
---
<em>Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template</em>
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.
The TensorFlow docs are open source! To get involved, read the documentation
contributor guide: https://www.tensorflow.org/community/contribute/docs
**System information**
- TensorFlow version:
- Doc Link:
## URL(s) with the issue:
Please provide a link to the documentation entry, for example:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/MyMethod
**Describe the documentation issue**
## Description of issue (what needs changing):
**We welcome contributions by users. Will you be able to update submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?**
### Clear description
For example, why should someone use this method? How is it useful?
### Correct links
Is the link to the source code correct?
### Parameters defined
Are all parameters defined and formatted correctly?
### Returns defined
Are return values defined?
### Raises listed and defined
Are the errors defined? For example,
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_file#raises
### Usage example
Is there a usage example?
### Request visuals, if applicable
Are there currently visuals? If not, will it clarify the content?
### Submit a pull request?
Are you planning to also submit a pull request to fix the issue? See the docs
contributor guide: https://www.tensorflow.org/community/contribute/docs and the
docs style guide: https://www.tensorflow.org/community/contribute/docs_style

20
.gitignore vendored
View File

@ -20,15 +20,8 @@ tensorflow/contrib/cmake/_build/
[Bb]uild/
/tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp
Pods
Podfile.lock
*.pbxproj
*.xcworkspacedata
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/gen/**
/tensorflow/lite/examples/ios/simple/data/*.txt
/tensorflow/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
/tensorflow/lite/tools/make/downloads/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl
@ -39,3 +32,14 @@ xcuserdata/**
*.iml
local.properties
gradleBuild
# iOS
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/[ios|objc|swift]*/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock
Pods
xcuserdata

View File

@ -32,7 +32,7 @@ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
You can obtain the TensorFlow version with:
```bash
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
```
### Describe the problem

View File

@ -1,4 +1,4 @@
Copyright 2018 The TensorFlow Authors. All rights reserved.
Copyright 2019 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004

View File

@ -1,5 +1,5 @@
<div align="center">
<img src="https://www.tensorflow.org/images/tf_logo_transp.png"><br><br>
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>
-----------------
@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
compatible API's for C++, Go, Java, JavaScript and Swift.
compatible API's for C++, Go, Java, JavaScript, and Swift.
Keep up to date with release announcements and security updates by
subscribing to
@ -50,10 +50,10 @@ instructions, and how to build from source.*
People who are a little more adventurous can also try our nightly binaries:
**Nightly pip packages**
* We are pleased to announce that TensorFlow now offers nightly pip packages
under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi.
**Nightly pip packages** * We are pleased to announce that TensorFlow now offers
nightly pip packages under the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi.
Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean
environment to install the nightly TensorFlow build. We support CPU and GPU
packages on Linux, Mac, and Windows.
@ -85,7 +85,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
uphold this code.**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, so please see
tracking requests and bugs, please see
[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
for general questions and discussion, and please direct specific questions to
[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
@ -114,15 +114,16 @@ The TensorFlow project strives to abide by generally accepted best practices in
### Community Supported Builds
Build Type | Status | Artifacts
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.4<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)<br>[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)<br>[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl)
Build Type | Status | Artifacts
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/)
## For more information

View File

@ -1,3 +1,212 @@
# Release 1.12.2
## Bug Fixes and Other Changes
* Fixes a potential security vulnerability where carefully crafted GIF images
can produce a null pointer dereference during decoding.
# Release 1.13.0
## Major Features and Improvements
* TensorFlow Lite has moved from contrib to core. This means that Python modules are under `tf.lite` and source code is now under `tensorflow/lite` rather than `tensorflow/contrib/lite`.
* TensorFlow GPU binaries are now built against CUDA 10 and TensorRT 5.0.
* Support for Python3.7 on all operating systems.
* Moved NCCL to core.
## Behavioral changes
* Disallow conversion of python floating types to uint32/64 (matching behavior of other integer types) in `tf.constant`.
* Make the `gain` argument of convolutional orthogonal initializers (`convolutional_delta_orthogonal`, `convolutional_orthogonal_1D`, `convolutional_orthogonal_2D`, `convolutional_orthogonal_3D`) have consistent behavior with the `tf.initializers.orthogonal` initializer, i.e. scale the output l2-norm by `gain` and NOT by `sqrt(gain)`. (Note that these functions are currently in `tf.contrib` which is not guaranteed backward compatible).
## Bug Fixes and Other Changes
* Documentation
* Update the doc with the details about the rounding mode used in
quantize_and_dequantize_v2.
* Clarify that tensorflow::port::InitMain() _should_ be called before
using the TensorFlow library. Programs failing to do this are not
portable to all platforms.
* Deprecations and Symbol renames.
* Removing deprecations for the following endpoints: `tf.acos`,
`tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`,
`tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`,
`tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`,
`tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`,
`tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`,
`tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan`
* Deprecate `tf.data.Dataset.shard`.
* Deprecate `saved_model.loader.load` which is replaced by
`saved_model.load` and `saved_model.main_op`, which will be replaced by
`saved_model.main_op` in V2.
* Deprecate tf.QUANTIZED_DTYPES. The official new symbol is
tf.dtypes.QUANTIZED_DTYPES.
* Update sklearn imports for deprecated packages.
* Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of
`Dataset.range`.
* Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of
`tf.train.confusion_matrix`.
* Add `tf.dtypes.` endpoint for every constant in dtypes.py. Moving
endpoints in versions.py to corresponding endpoints in `tf.sysconfig.`
and `tf.version.`. Moving all constants under `tf.saved_model`
submodules to `tf.saved_model` module. New endpoints are added in V1 and
V2 but existing endpoint removals are only applied in V2.
* Deprecates behavior where device assignment overrides collocation
constraints inside a collocation context manager.
* Keras & Python API
* Add to Keras functionality analogous to
`tf.register_tensor_conversion_function`.
* Subclassed Keras models can now be saved through
`tf.contrib.saved_model.save_keras_model`.
* `LinearOperator.matmul` now returns a new `LinearOperator`.
* New ops and improved op functionality
* Add a Nearest Neighbor Resize op.
* Add an `ignore_unknown` argument to `parse_values` which suppresses
ValueError for unknown hyperparameter types. Such * Add
`tf.linalg.matvec` convenience function.
* `tf.einsum()`raises `ValueError` for unsupported equations like
`"ii->"`.
* Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`.
* Add LU decomposition op.
* Add quantile loss to gradient boosted trees in estimator.
* Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding
algorithm.
* Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`,
`unicode_split`, `unicode_split_with_offset`, and `unicode_transcode`
ops. Amongst other things, this Op adds the ability to encode, decode,
and transcode a variety of input text encoding formats into the main
Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE)
* Add "unit" attribute to the substr op, which allows obtaining the
substring of a string containing unicode characters.
* Broadcasting support for Ragged Tensors.
* `SpaceToDepth` supports uint8 data type.
* Support multi-label quantile regression in estimator.
* We now use "div" as the default partition_strategy in
`tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and
`tf.nn.nce_loss`. hyperparameter are ignored.
* Performance
* Improve performance of GPU cumsum/cumprod by up to 300x.
* Added support for weight decay in most TPU embedding optimizers,
including AdamW and MomentumW.
* TensorFlow 2.0 Development
* Add a command line tool to convert to TF2.0, tf_upgrade_v2
* Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0.
* Change the default recurrent activation function for LSTM from
'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is
'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend
between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we
change the default for CPU mode to sigmoid as well. With that, the
default LSTM will be compatible with both CPU and GPU kernel. This will
enable user with GPU to use CuDNN kernel by default and get a 10x
performance boost in training. Note that this is checkpoint breaking
change. If user want to use their 1.x pre-trained checkpoint, please
construct the layer with LSTM(recurrent_activation='hard_sigmoid') to
fallback to 1.x behavior.
* TensorFlow Lite
* Move from `tensorflow/contrib/lite` to `tensorflow/lite`.
* Add experimental Java API for injecting TensorFlow Lite delegates
* Add support for strings in TensorFlow Lite Java API.
* `tf.contrib`:
* Add Apache Ignite Filesystem plugin to support accessing Apache IGFS.
* Dropout now takes `rate` argument, `keep_prob` is deprecated.
* Estimator occurrences references `tf.contrib.estimator` were changed to
`tf.estimator`:
* `tf.contrib.estimator.BaselineEstimator` with
`tf.estimator.BaselineEstimator`
* `tf.contrib.estimator.DNNLinearCombinedEstimator` with
`tf.estimator.DNNLinearCombinedEstimator`
* `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator`
* `tf.contrib.estimator.LinearEstimator` with
`tf.estimator.LinearEstimator`
* `tf.contrib.estimator.InMemoryEvaluatorHook` and
tf.estimator.experimental.InMemoryEvaluatorHook`.
* `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`.
* Expose `tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
* Migrate linear optimizer from contrib to core.
* Move `tf.contrib.signal` to `tf.signal` (preserving aliases in
tf.contrib.signal).
* Users of `tf.contrib.estimator.export_all_saved_models` and related
should switch to
`tf.estimator.Estimator.experimental_export_all_saved_models`.
* tf.data:
* Add `tf.data.experimental.StatsOptions()`, to configure options to
collect statistics from `tf.data.Dataset` pipeline using
`StatsAggregator`. Add nested option, `experimental_stats` (which takes
a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`.
Deprecates `tf.data.experimental.set_stats_agregator`.
* Performance optimizations:
* Add `tf.data.experimental.OptimizationOptions()`, to configure options
to enable `tf.data` performance optimizations. Add nested option,
`experimental_optimization` (which takes a
`tf.data.experimental.OptimizationOptions` object), to
`tf.data.Options`. Remove performance optimization options from
`tf.data.Options`, and add them under
`tf.data.experimental.OptimizationOptions` instead.
* Enable `map_and_batch_fusion` and `noop_elimination` optimizations by
default. They can be disabled by configuring
`tf.data.experimental.OptimizationOptions` to set `map_and_batch =
False` or `noop_elimination = False` respectively. To disable all
default optimizations, set `apply_default_optimizations = False`.
* Support parallel map in `map_and_filter_fusion`.
* Disable static optimizations for input pipelines that use non-resource
`tf.Variable`s.
* Add NUMA-aware MapAndBatch dataset.
* Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it
from V2, and added tf.compat.v1.data.make_one_shot_iterator()`.
* Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed
it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`.
* Enable nested dataset support in core `tf.data` transformations.
* For `tf.data.Dataset` implementers: Added
`tf.data.Dataset._element_structured property` to replace
`Dataset.output_{types,shapes,classes}`.
* Make `num_parallel_calls` of `tf.data.Dataset.interleave` and
`tf.data.Dataset.map` work in Eager mode.
* Toolchains
* Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`.
* Added bounds checking to printing deprecation warnings.
* Upgraded CUDA dependency to 10.0
* To build with Android NDK r14b, add "#include <linux/compiler.h>" to
android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h
* Removed `:android_tensorflow_lib_selective_registration*` targets, use
`:android_tensorflow_lib_lite*` targets instead.
* XLA
* Move `RoundToEven` function to xla/client/lib/math.h.
* A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1"
or "true" allows the debug options passed within an XRTCompile op to be
passed directly to the XLA compilation backend. If such variable is not
set (service side), only a restricted set will be passed through.
* Allow the XRTCompile op to return the ProgramShape resulted form the XLA
compilation as a second return argument.
* XLA HLO graphs can now be rendered as SVG/HTML.
* Estimator
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with
`tf.estimator.BaselineEstimator`
* Replace all occurences of
`tf.contrib.estimator.DNNLinearCombinedEstimator` with
`tf.estimator.DNNLinearCombinedEstimator`
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with
`tf.estimator.DNNEstimator`
* Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with
`tf.estimator.LinearEstimator`
* Users of `tf.contrib.estimator.export_all_saved_models` and related
should switch to
`tf.estimator.Estimator.experimental_export_all_saved_models`.
* Update `regression_head` to the new Head API for Canned Estimator V2.
* Switch `multi_class_head` to Head API for Canned Estimator V2.
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook`
and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
`tf.estimator.experimental.InMemoryEvaluatorHook` and
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
* Migrate linear optimizer from contrib to core.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
Abhinav Upadhyay, Ag Ramesh, akikaaa, Alexis Louis, Anders Huss, Andreas Madsen, Andrew Banchich, Andy Craze, Anton Dmitriev, Artem Malykh, Avijit-Nervana, Balint Cristian, Benjamin Tan Wei Hao, Bhavani Subramanian, Brendan Finan, Brian Nemsick, Bryan Cutler, By Shen, Cao Zongyan, Castiel, Chris Antaki, Christian Goll, Cibifang, Clayne Robison, Codrut Grosu, Cong Xu, Dalmo Cirne, Daniel Hunter, Dougal J. Sutherland, Edvard Fagerholm, EFanZh, Erik Smistad, Evgeniy Polyakov, Feiyang Chen, franklin5, Fred Reiss, Gautam, gehring, Geoffrey Irving, George Sterpu, Gitea, Grzegorz George Pawelczak, Guozhong Zhuang, himkt, Hoeseong Kim, Huan Li (李卓桓), HuiyangFei, hyunyoung, Isaac Burbank, jackonan, Jacky Ko, Jason Furmanek, Jason Zaman, Javier Luraschi, Jiang,Zhoulong, joaak, John Lin, Jonathan Wyatt Hoech, josephyearsley, Josh Gordon, Julian Niedermeier, Karl Lessard, Keno Fischer, lanhin, Leon Graser, leondgarse, Li, Guizi, Li, Yiqiang, lxl910915, Mahmoud Abuzaina, manhyuk, Marcela Morales Quispe, margaretmz, Matt Conley, Max Pumperla, mbhuiyan, mdfaijul, Meng, Peng, Michael, Michael Gielda, mrTsjolder, Muhammad Wildan, neargye, Nehal J Wani, NEWPLAN, Niranjan Hasabnis, Nutti, olicht, Pan Daoxin, Pedro Monreal, Peng Yu, pillarpond, Pooya Davoodi, qiezi, Rholais Lii, Richard Yu, Rin Arakaki, Roger Iyengar, sahilbadyal, Sami Kama, Sandip Giri, Scott Leishman, Serge Panev, Seunghoon Park, Shafi Dayatar, shengfuintel, Shimin Guo, Siju, silent567, Stefan Dyulgerov, steven, Tao Wei, Thor Johnsen, Tingbo Lu, tomguluson92, Tongxuan Liu, Trevor Morris, Ubuntu, Vadim Borisov, vanderliang, wangsiyu, Wen Yun, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, Xiaoming (Jason) Cui, Yan Facai (颜发才), Yanbo Liang, Yaniv Blumenfeld, Yash Gaurkar, Yicheng Fan, Yong Tang, Yongjoon Lee, Yuan (Terry) Tang, Yuxin Wu, zldrobit
# Release 1.12.0
## Major Features and Improvements
@ -38,21 +247,21 @@
* Remove integer types from `tf.nn.softplus` and `tf.nn.softsign` OpDefs.
This is a bugfix; these ops were never meant to support integers.
* Allow subslicing Tensors with a single dimension.
* Add option to calculate string length in Unicode characters
* Add option to calculate string length in Unicode characters.
* Add functionality to SubSlice a tensor.
* Add searchsorted (ie lower/upper_bound) op.
* Add model explainability to Boosted Trees.
* Support negative positions for tf.substr
* Support negative positions for tf.substr.
* There was previously a bug in the bijector_impl where the
_reduce_jacobian_det_over_event does not handle scalar ILDJ
implementations properly.
* In tf eager execution, allow re-entering a GradientTape context
* In tf eager execution, allow re-entering a GradientTape context.
* Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in,
then bazel will build TensorFlow API version 2.0. Note that TensorFlow
2.0 is under active development and has no guarantees at this point.
* Add additional compression options to TfRecordWriter
* Add additional compression options to TfRecordWriter.
* Performance improvements for regex full match operations.
* Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES`
* Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES`.
* Remove unused dynamic learning rate support.
## Thanks to our Contributors
@ -75,15 +284,22 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为
## Major Features and Improvements
* Nvidia GPU:
* Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support)
* Google Cloud TPU:
* Experimental tf.data integration for Keras on Google Cloud TPUs.
* Experimental / preview support for eager execution on Google Cloud TPUs.
* DistributionStrategy:
* Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs.
* Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details.
* Add C, C++, and Python functions for querying kernels
* Nvidia GPU:
* Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN
7.2 and TensorRT 4. See updated install guides:
[Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support)
* Google Cloud TPU:
* Experimental tf.data integration for Keras on Google Cloud TPUs.
* Experimental / preview support for eager execution on Google Cloud TPUs.
* DistributionStrategy:
* Add multi-GPU DistributionStrategy support in tf.keras. Users can now
use `fit`, `evaluate` and `predict` to distribute their model on
multiple GPUs.
* Add multi-worker DistributionStrategy and standalone client support in
Estimator. See
[README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute)
for more details.
* Add C, C++, and Python functions for querying kernels.
## Breaking Changes
@ -134,18 +350,18 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为
* Deprecate self.test_session() in favor of self.session() or
self.cached_session().
* Directly import tensor.proto.h (the transitive import will be removed
from tensor.h soon)
from tensor.h soon).
* Estimator.train() now supports tf.contrib.summary.\* summaries out of
the box; each call to .train() will now create a separate tfevents file
rather than re-using a shared one.
* Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term
should not end up in the accumulator.
* Fix toco compilation/execution on Windows
* Fix toco compilation/execution on Windows.
* GoogleZoneProvider class added to detect which Google Cloud Engine zone
tensorflow is running in.
* It is now safe to call any of the C API's TF_Delete\* functions on
nullptr
* Log some errors on Android to logcat
nullptr.
* Log some errors on Android to logcat.
* Match FakeQuant numerics in TFLite to improve accuracy of TFLite
quantized inference models.
* Optional bucket location check for the GCS Filesystem.
@ -166,7 +382,7 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为
the existing zero_state() method.
* Update initialization of variables in Keras.
* Updates to "constrained_optimization" in tensorflow/contrib.
* boosted trees: adding pruning mode
* boosted trees: adding pruning mode.
* tf.train.Checkpoint does not delete old checkpoints by default.
* tfdbg: Limit the total disk space occupied by dumped tensor data to 100
GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow

View File

@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file"
http_archive(
name = "io_bazel_rules_closure",
sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1",
strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
],
)
@ -43,17 +43,37 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "73b4980a318d203d3307f850e27e66ec5cc8d223147a3475a6f11597eb6438a5",
strip_prefix = "rules_apple-0.13.0",
urls = ["https://github.com/bazelbuild/rules_apple/archive/0.13.0.tar.gz"],
)
sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.5.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)
load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies")
apple_rules_dependencies()
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()

View File

@ -33,13 +33,11 @@ except ImportError:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
_DEFAULT_CUDA_VERSION = '10.0'
_DEFAULT_CUDA_VERSION = '10'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '5'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@ -50,21 +48,24 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
]
# List of files to be configured for using Bazel on Apple platforms.
# List of files to configure when building Bazel on Apple platforms.
APPLE_BAZEL_FILES = [
'tensorflow/lite/experimental/ios/BUILD',
'tensorflow/lite/experimental/objc/BUILD',
'tensorflow/lite/experimental/swift/BUILD'
]
if platform.machine() == 'ppc64le':
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
else:
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
# List of files to move when building for iOS.
IOS_FILES = [
'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec',
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
]
class UserInputError(Exception):
@ -199,9 +200,10 @@ def setup_python(environ_cp):
ask_python_bin_path = ('Please specify the location of python. [Default is '
'%s]: ') % default_python_bin_path
while True:
python_bin_path = get_from_env_or_user_or_default(
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
default_python_bin_path)
python_bin_path = get_from_env_or_user_or_default(environ_cp,
'PYTHON_BIN_PATH',
ask_python_bin_path,
default_python_bin_path)
# Check if the path is valid
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
@ -291,9 +293,9 @@ def get_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
@ -337,8 +339,8 @@ def get_var(environ_cp,
'Environment variable %s must be set as a boolean indicator.\n'
'The following are accepted as TRUE : %s.\n'
'The following are accepted as FALSE: %s.\n'
'Current value is %s.' % (var_name, ', '.join(true_strings),
', '.join(false_strings), var))
'Current value is %s.' %
(var_name, ', '.join(true_strings), ', '.join(false_strings), var))
while var is None:
user_input_origin = get_input(question)
@ -374,9 +376,9 @@ def set_build_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
option_name: string for option to define in .bazelrc.
enabled_by_default: boolean for default behavior.
bazel_config_name: Name for Bazel --config argument to enable build feature.
@ -385,14 +387,14 @@ def set_build_var(environ_cp,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
if var == '1':
write_to_bazelrc(
'build:%s --define %s=true' % (bazel_config_name, option_name))
write_to_bazelrc('build:%s --define %s=true' %
(bazel_config_name, option_name))
write_to_bazelrc('build --config=%s' % bazel_config_name)
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
write_to_bazelrc(
'build:%s --define %s=true' % (bazel_config_name, option_name))
write_to_bazelrc('build:%s --define %s=true' %
(bazel_config_name, option_name))
def set_action_env_var(environ_cp,
@ -409,9 +411,9 @@ def set_action_env_var(environ_cp,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
query_item: string for feature related to the variable, e.g. "Hadoop File
System".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
query_item: string for feature related to the variable, e.g. "CUDA for
Nvidia GPUs".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optional string for reply when feature is enabled.
@ -439,6 +441,9 @@ def convert_version_to_int(version):
"""
version = version.split('-')[0]
version_segments = version.split('.')
# Treat "0.24" as "0.24.0"
if len(version_segments) == 2:
version_segments.append('0')
for seg in version_segments:
if not seg.isdigit():
return None
@ -451,8 +456,8 @@ def check_bazel_version(min_version, max_version):
"""Check installed bazel version is between min_version and max_version.
Args:
min_version: string for minimum bazel version.
max_version: string for maximum bazel version.
min_version: string for minimum bazel version (must exist!).
max_version: string for maximum bazel version (must exist!).
Returns:
The bazel version detected.
@ -565,7 +570,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var,
Args:
environ_cp: copy of the os.environ.
var_name: string for name of environment variable, e.g. "TF_NEED_HDFS".
var_name: string for name of environment variable, e.g. "TF_NEED_CUDA".
ask_for_var: string for how to ask for user input.
var_default: default value string.
@ -658,9 +663,9 @@ def prompt_loop_or_load_from_env(environ_cp,
print(error_msg % val)
environ_cp[var_name] = ''
else:
raise UserInputError(
'Invalid %s setting was provided %d times in a row. '
'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts))
raise UserInputError('Invalid %s setting was provided %d times in a row. '
'Assuming to be a scripting mistake.' %
(var_name, n_ask_attempts))
environ_cp[var_name] = val
return val
@ -669,8 +674,8 @@ def prompt_loop_or_load_from_env(environ_cp,
def create_android_ndk_rule(environ_cp):
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
if is_windows() or is_cygwin():
default_ndk_path = cygpath(
'%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA'])
default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
environ_cp['APPDATA'])
elif is_macos():
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
@ -689,8 +694,9 @@ def create_android_ndk_rule(environ_cp):
error_msg=('The path %s or its child file "source.properties" '
'does not exist.'))
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
check_ndk_level(android_ndk_home_path))
write_action_env_to_bazelrc(
'ANDROID_NDK_API_LEVEL',
get_ndk_api_level(environ_cp, android_ndk_home_path))
def create_android_sdk_rule(environ_cp):
@ -757,8 +763,10 @@ def create_android_sdk_rule(environ_cp):
write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
"""Check the revision number of an Android NDK path."""
def get_ndk_api_level(environ_cp, android_ndk_home_path):
"""Gets the appropriate NDK API level to use for the provided Android NDK path."""
# First check to see if we're using a blessed version of the NDK.
properties_path = '%s/source.properties' % android_ndk_home_path
if is_windows() or is_cygwin():
properties_path = cygpath(properties_path)
@ -767,16 +775,40 @@ def check_ndk_level(android_ndk_home_path):
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
if revision:
ndk_api_level = revision.group(1)
ndk_version = revision.group(1)
else:
raise Exception('Unable to parse NDK revision.')
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The API level of the NDK in %s is %s, which is not '
if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The NDK version in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_api_level,
'errors.\n' % (android_ndk_home_path, ndk_version,
_SUPPORTED_ANDROID_NDK_VERSIONS))
return ndk_api_level
# Now grab the NDK API level to use. Note that this is different from the
# SDK API level, as the NDK API level is effectively the *min* target SDK
# version.
platforms = os.path.join(android_ndk_home_path, 'platforms')
api_levels = sorted(os.listdir(platforms))
api_levels = [
x.replace('android-', '') for x in api_levels if 'android-' in x
]
def valid_api_level(api_level):
return os.path.exists(
os.path.join(android_ndk_home_path, 'platforms',
'android-' + api_level))
android_ndk_api_level = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_NDK_API_LEVEL',
var_default='18', # 18 is required for GPU acceleration.
ask_for_var=('Please specify the (min) Android NDK API level to use. '
'[Available levels: %s]') % api_levels,
check_success=valid_api_level,
error_msg='Android-%s is not present in the NDK path.')
return android_ndk_api_level
def set_gcc_host_compiler_path(environ_cp):
@ -823,149 +855,39 @@ def reformat_version_sequence(version_str, sequence_count):
return '.'.join(v[:sequence_count])
def set_tf_cuda_paths(environ_cp):
"""Set TF_CUDA_PATHS."""
ask_cuda_paths = (
'Please specify the comma-separated list of base paths to look for CUDA '
'libraries and headers. [Leave empty to use the default]: ')
tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS',
ask_cuda_paths, '')
if tf_cuda_paths:
environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
"""Set TF_CUDA_VERSION."""
ask_cuda_version = (
'Please specify the CUDA SDK version you want to use. '
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
tf_cuda_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
if is_windows() or is_cygwin():
default_cuda_path = cygpath(
environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
elif is_linux():
# If the default doesn't exist, try an alternative default.
if (not os.path.exists(default_cuda_path)
) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
' installed. Refer to README.md for more details. '
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
cuda_toolkit_path = get_from_env_or_user_or_default(
environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
if is_windows() or is_cygwin():
cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
cuda_rt_lib_paths = [
'%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
'lib64',
'lib/powerpc64le-linux-gnu',
'lib/x86_64-linux-gnu',
]
]
elif is_macos():
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
(tf_cuda_version, cuda_toolkit_paths_full))
environ_cp['TF_CUDA_VERSION'] = ''
environ_cp['CUDA_TOOLKIT_PATH'] = ''
else:
raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
tf_cuda_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDA_VERSION',
ask_cuda_version,
_DEFAULT_CUDA_VERSION)
environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
def set_tf_cudnn_version(environ_cp):
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
"""Set TF_CUDNN_VERSION."""
ask_cudnn_version = (
'Please specify the cuDNN version you want to use. '
'[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_cudnn_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
'installed. Refer to README.md for more details. [Default'
' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
cudnn_install_path = get_from_env_or_user_or_default(
environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
cudnn_install_path = os.path.realpath(
os.path.expanduser(cudnn_install_path))
if is_windows() or is_cygwin():
cudnn_install_path = cygpath(cudnn_install_path)
if is_windows():
cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
elif is_linux():
cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
elif is_macos():
cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
cuda_dnn_lib_alt_path)
if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
cuda_dnn_lib_alt_path_full):
break
# Try another alternative for Linux
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
if os.path.exists(
'%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
# Reset and Retry
print(
'Invalid path to cuDNN %s toolkit. None of the following files can be '
'found:' % tf_cudnn_version)
print(cuda_dnn_lib_path_full)
print(cuda_dnn_lib_alt_path_full)
if is_linux():
print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
environ_cp['TF_CUDNN_VERSION'] = ''
else:
raise UserInputError('Invalid TF_CUDNN setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
tf_cudnn_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDNN_VERSION',
ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
@ -997,252 +919,38 @@ def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
return cudnn_ok and cuda_ok
def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
Adapted from code contributed by Sami Kama (https://github.com/samikama).
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_tensorrt_version(environ_cp):
"""Set TF_TENSORRT_VERSION."""
if not is_linux():
raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support.
if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
False))) != '1':
if not int(environ_cp.get('TF_NEED_TENSORRT', False)):
return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
ask_tensorrt_path = (r'Please specify the location where TensorRT is '
'installed. [Default is %s]:') % (
_DEFAULT_TENSORRT_PATH_LINUX)
trt_install_path = get_from_env_or_user_or_default(
environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
_DEFAULT_TENSORRT_PATH_LINUX)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
def find_libs(search_path):
"""Search for libnvinfer.so in "search_path"."""
fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path):
fl.update([
os.path.realpath(os.path.join(search_path, x))
for x in os.listdir(search_path)
if 'libnvinfer.so' in x
])
return fl
possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
if not matches.groups():
continue
ver_str = matches.group(1)
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
if ver > highest_ver[0]:
highest_ver = [ver, ver_str, lib_file]
if highest_ver[1] is not None:
trt_install_path = os.path.dirname(highest_ver[2])
tf_tensorrt_version = highest_ver[1]
break
# Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p'])
search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
ldconfig_output)
if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig):
if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1)
break
# Reset and Retry
if possible_files:
print('TensorRT libraries found in one the following directories',
'are not compatible with selected cuda and cudnn installations')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
print(
'Invalid path to TensorRT. None of the following files can be found:')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
ask_tensorrt_version = (
'Please specify the TensorRT version you want to use. '
'[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION
tf_tensorrt_version = get_from_env_or_user_or_default(
environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version,
_DEFAULT_TENSORRT_VERSION)
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
def set_tf_nccl_install_path(environ_cp):
"""Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_nccl_version(environ_cp):
"""Set TF_NCCL_VERSION."""
if not is_linux():
raise ValueError('Currently NCCL is only supported on Linux platforms.')
raise ValueError('Currently NCCL is only supported on Linux platform.')
if 'TF_NCCL_VERSION' in environ_cp:
return
ask_nccl_version = (
'Please specify the locally installed NCCL version you want to use. '
'[Default is to use https://github.com/nvidia/nccl]: ')
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_nccl_version = get_from_env_or_user_or_default(
environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '')
if not tf_nccl_version:
break # No need to get install path, building the open source code.
tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
# Look with ldconfig first if we can find the library in paths
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
# include directory. This is where the NCCL .deb packages install them.
# First check to see if NCCL is in the ldconfig.
# If its found, use that location.
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
nccl2_path_from_ldconfig)
if nccl2_path_from_ldconfig:
nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
# Check if this is the main system lib location
if re.search('.*linux-gnu', nccl_install_path):
trunc_nccl_install_path = '/usr'
print('This looks like a system path.')
else:
trunc_nccl_install_path = nccl_install_path + '/..'
# Look for header
nccl_hdr_path = trunc_nccl_install_path + '/include'
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_hdr_path + '/nccl.h'):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
break
else:
print(
'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
)
else:
print('NCCL2 is listed by ldconfig but the library is not found. '
'Your ldconfig is out of date. Please run sudo ldconfig.')
else:
# NCCL is not found in ldconfig. Ask the user for the location.
default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_nccl_path = (
r'Please specify the location where NCCL %s library is '
'installed. Refer to README.md for more details. [Default '
'is %s]:') % (tf_nccl_version, default_nccl_path)
nccl_install_path = get_from_env_or_user_or_default(
environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
nccl_install_path = os.path.realpath(
os.path.expanduser(nccl_install_path))
if is_windows() or is_cygwin():
nccl_install_path = cygpath(nccl_install_path)
nccl_lib_path = ''
if is_windows():
nccl_lib_path = 'lib/x64/nccl.lib'
elif is_linux():
nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
if not os.path.exists(nccl_lpath):
for relative_path in NCCL_LIB_PATHS:
path = '%s/%s%s' % (nccl_install_path, relative_path,
nccl_lib_filename)
if os.path.exists(path):
print('NCCL found at ' + path)
nccl_lib_path = path
break
else:
nccl_lib_path = nccl_lpath
elif is_macos():
nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(
os.path.dirname(nccl_lib_path), '../include/nccl.h')
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
os.path.dirname(nccl_lib_path))
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
write_action_env_to_bazelrc('NCCL_HDR_PATH',
os.path.dirname(nccl_hdr_path))
break
# Reset and Retry
print(
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
nccl_hdr_path))
environ_cp['TF_NCCL_VERSION'] = ''
else:
raise UserInputError('Invalid TF_NCCL setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TF_NCCL_VERSION
'[Leave empty to use http://github.com/nvidia/nccl]: ')
tf_nccl_version = get_from_env_or_user_or_default(environ_cp,
'TF_NCCL_VERSION',
ask_nccl_version, '')
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@ -1305,11 +1013,14 @@ def set_tf_cuda_compute_capabilities(environ_cp):
all_valid = False
else:
ver = float(m.group(0))
if ver < 3.5:
print('ERROR: TensorFlow only supports CUDA compute capabilities 3.5 '
if ver < 3.0:
print('ERROR: TensorFlow only supports CUDA compute capabilities 3.0 '
'and higher. Please re-specify the list of compute '
'capabilities excluding version %s.' % ver)
all_valid = False
if ver < 3.5:
print('WARNING: XLA does not support CUDA compute capabilities '
'lower than 3.5. Disable XLA when running on older GPUs.')
if all_valid:
break
@ -1328,10 +1039,8 @@ def set_other_cuda_vars(environ_cp):
# If CUDA is enabled, always use GPU during build and test.
if environ_cp.get('TF_CUDA_CLANG') == '1':
write_to_bazelrc('build --config=cuda_clang')
write_to_bazelrc('test --config=cuda_clang')
else:
write_to_bazelrc('build --config=cuda')
write_to_bazelrc('test --config=cuda')
def set_host_cxx_compiler(environ_cp):
@ -1495,15 +1204,16 @@ def set_other_mpi_vars(environ_cp):
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
(mpi_home, mpi_home, mpi_home))
def system_specific_test_config(env):
"""Add default test flags required for TF tests to bazelrc."""
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
write_to_bazelrc('test --test_size_filters=small,medium')
write_to_bazelrc(
'test --test_tag_filters=-benchmark-test,-no_oss,-oss_serial')
write_to_bazelrc('test --build_tag_filters=-benchmark-test,-no_oss')
if is_windows():
if env.get('TF_NEED_CUDA', None) == 1:
if env.get('TF_NEED_CUDA', None) == '1':
write_to_bazelrc(
'test --test_tag_filters=-no_windows,-no_windows_gpu,-no_gpu')
write_to_bazelrc(
@ -1515,7 +1225,7 @@ def system_specific_test_config(env):
write_to_bazelrc('test --test_tag_filters=-gpu,-nomac,-no_mac')
write_to_bazelrc('test --build_tag_filters=-gpu,-nomac,-no_mac')
elif is_linux():
if env.get('TF_NEED_CUDA', None) == 1:
if env.get('TF_NEED_CUDA', None) == '1':
write_to_bazelrc('test --test_tag_filters=-no_gpu')
write_to_bazelrc('test --build_tag_filters=-no_gpu')
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
@ -1549,7 +1259,8 @@ def set_windows_build_flags(environ_cp):
write_to_bazelrc('build --copt=-w --host_copt=-w')
# Fix winsock2.h conflicts
write_to_bazelrc(
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN')
'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN '
'--copt=-DNOGDI --host_copt=-DNOGDI')
# Output more verbose information when something goes wrong
write_to_bazelrc('build --verbose_failures')
# The host and target platforms are the same in Windows build. So we don't
@ -1575,26 +1286,90 @@ def config_info_line(name, help_text):
print('\t--config=%-12s\t# %s' % (name, help_text))
def configure_apple_bazel_rules():
"""Configures Bazel rules for building on Apple platforms.
def configure_ios():
"""Configures TensorFlow for iOS builds.
Enables analyzing and building Apple Bazel rules on Apple platforms. This
function will only be executed if `is_macos()` is true.
This function will only be executed if `is_macos()` is true.
"""
if not is_macos():
return
for filepath in APPLE_BAZEL_FILES:
if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000:
print(
'Configuring %s file to analyze and build Bazel rules on Apple platforms.'
% filepath)
'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.')
for filepath in APPLE_BAZEL_FILES:
existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple')
renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath)
os.rename(existing_filepath, renamed_filepath)
symlink_force(existing_filepath, renamed_filepath)
for filepath in IOS_FILES:
filename = os.path.basename(filepath)
new_filepath = os.path.join(_TF_WORKSPACE_ROOT, filename)
symlink_force(filepath, new_filepath)
def validate_cuda_config(environ_cp):
"""Run find_cuda_config.py and return cuda_toolkit_path, or None."""
def maybe_encode_env(env):
"""Encodes unicode in env to str on Windows python 2.x."""
if not is_windows() or sys.version_info[0] != 2:
return env
for k, v in env.items():
if isinstance(k, unicode):
k = k.encode('ascii')
if isinstance(v, unicode):
v = v.encode('ascii')
env[k] = v
return env
cuda_libraries = ['cuda', 'cudnn']
if is_linux():
if int(environ_cp.get('TF_NEED_TENSORRT', False)):
cuda_libraries.append('tensorrt')
if environ_cp.get('TF_NCCL_VERSION', None):
cuda_libraries.append('nccl')
proc = subprocess.Popen(
[environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] +
cuda_libraries,
stdout=subprocess.PIPE,
env=maybe_encode_env(environ_cp))
if proc.wait():
# Errors from find_cuda_config.py were sent to stderr.
print('Asking for detailed CUDA configuration...\n')
return False
config = dict(
tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout)
print('Found CUDA %s in:' % config['cuda_version'])
print(' %s' % config['cuda_library_dir'])
print(' %s' % config['cuda_include_dir'])
print('Found cuDNN %s in:' % config['cudnn_version'])
print(' %s' % config['cudnn_library_dir'])
print(' %s' % config['cudnn_include_dir'])
if 'tensorrt_version' in config:
print('Found TensorRT %s in:' % config['tensorrt_version'])
print(' %s' % config['tensorrt_library_dir'])
print(' %s' % config['tensorrt_include_dir'])
if config.get('nccl_version', None):
print('Found NCCL %s in:' % config['nccl_version'])
print(' %s' % config['nccl_library_dir'])
print(' %s' % config['nccl_include_dir'])
print('\n')
environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path']
return True
def main():
global _TF_WORKSPACE_ROOT
global _TF_BAZELRC
global _TF_CURRENT_BAZEL_VERSION
parser = argparse.ArgumentParser()
parser.add_argument(
@ -1611,7 +1386,8 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
check_bazel_version('0.19.0', '0.23.0')
current_bazel_version = check_bazel_version('0.24.1', '0.25.2')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1633,7 +1409,7 @@ def main():
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'
else:
environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0'
environ_cp['TF_CONFIGURE_IOS'] = '0'
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
@ -1666,11 +1442,43 @@ def main():
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_install_path(environ_cp)
set_tf_nccl_install_path(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
environ_save = dict(environ_cp)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
if validate_cuda_config(environ_cp):
cuda_env_names = [
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
# Items below are for backwards compatibility when not using
# TF_CUDA_PATHS.
'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
]
# Note: set_action_env_var above already writes to bazelrc.
for name in cuda_env_names:
if name in environ_cp:
write_action_env_to_bazelrc(name, environ_cp[name])
break
# Restore settings changed below if CUDA config could not be validated.
environ_cp = dict(environ_save)
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_version(environ_cp)
set_tf_nccl_version(environ_cp)
set_tf_cuda_paths(environ_cp)
else:
raise UserInputError(
'Invalid CUDA setting were provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
set_tf_cuda_compute_capabilities(environ_cp)
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
@ -1688,7 +1496,6 @@ def main():
else:
# Use downloaded LLD for linking.
write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld')
write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld')
else:
# Set up which gcc nvcc should use as the host compiler
# No need to set this on Windows
@ -1701,7 +1508,6 @@ def main():
set_tf_download_clang(environ_cp)
if environ_cp.get('TF_DOWNLOAD_CLANG') == '1':
write_to_bazelrc('build --config=download_clang')
write_to_bazelrc('test --config=download_clang')
# SYCL / ROCm / CUDA are mutually exclusive.
# At most 1 GPU platform can be configured.
@ -1738,13 +1544,9 @@ def main():
system_specific_test_config(os.environ)
if get_var(
environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES',
'Configure Bazel rules for Apple platforms', False,
('Would you like to configure Bazel rules for building on Apple platforms?'
), 'Configuring Bazel rules for Apple platforms.',
'Not configuring Bazel rules for Apple platforms.'):
configure_apple_bazel_rules()
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
configure_ios()
print('Preconfigured Bazel build configs. You can use any of the below by '
'adding "--config=<>" to your build command. See .bazelrc for more '

View File

@ -15,6 +15,7 @@ exports_files([
"leakr_file_type_recipe.ftrcp",
])
load("//tensorflow:tensorflow.bzl", "VERSION")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
@ -163,7 +164,7 @@ config_setting(
name = "macos",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_x86_64",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
@ -183,6 +184,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_aarch64",
values = {"cpu": "aarch64"},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
@ -325,6 +332,18 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_with_framework_shared_object",
define_values = {
"framework_shared_object": "true",
},
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "using_cuda_clang",
define_values = {
@ -407,9 +426,15 @@ config_setting(
values = {"cpu": "x64_windows"},
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = ["//tensorflow/..."],
packages = [
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
],
)
load(
@ -467,7 +492,7 @@ cc_library(
# projects building with Bazel and importing TensorFlow as a dependency will not
# depend on libtensorflow_framework.so unless they opt in.
tf_cc_shared_object(
name = "libtensorflow_framework.so",
name = "tensorflow_framework",
framework_so = [],
linkopts = select({
"//tensorflow:macos": [],
@ -477,8 +502,11 @@ tf_cc_shared_object(
],
}),
linkstatic = 1,
per_os_targets = True,
soversion = VERSION,
visibility = ["//visibility:public"],
deps = [
"//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",
"//tensorflow/core:gpu_runtime_impl",
@ -508,7 +536,6 @@ tf_cc_shared_object(
linkopts = select({
"//tensorflow:macos": [
"-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)",
"-Wl,-install_name,@rpath/libtensorflow.so",
],
"//tensorflow:windows": [
],
@ -518,6 +545,7 @@ tf_cc_shared_object(
],
}),
per_os_targets = True,
soversion = VERSION,
visibility = ["//visibility:public"],
# add win_def_file for tensorflow
win_def_file = select({
@ -548,6 +576,7 @@ tf_cc_shared_object(
],
}),
per_os_targets = True,
soversion = VERSION,
visibility = ["//visibility:public"],
# add win_def_file for tensorflow_cc
win_def_file = select({

View File

@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
"""
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
TensorFlow via the command `import tensorflow as tf`.
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
sub-modules, as described below.
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
"""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
@ -20,10 +32,13 @@ from __future__ import print_function as _print_function
import distutils as _distutils
import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# Make sure directory containing top level submodules is in
@ -37,25 +52,29 @@ if not hasattr(_current_module, '__path__'):
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# pylint: disable=g-bad-import-order
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg="Limited tf.summary API due to missing TensorBoard installation")
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v2.estimator'))
# Hook external TensorFlow modules.
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v2.keras'))
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top

View File

@ -26,30 +26,44 @@ import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v1.estimator'))
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, 'estimator'):
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v1.keras'))
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
* https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.
"""
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib',
@ -65,17 +79,6 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-
# The 'app' module will be imported as part of the placeholder section above.
app.flags = flags # pylint: disable=undefined-variable
# Also use 'app' module (choice is arbitrary) to derive the API directory below.
_API_MODULE = app # pylint: disable=undefined-variable
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
@ -117,7 +120,11 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable
try:
del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
@ -128,6 +135,8 @@ except NameError:
# others don't exist.
try:
del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError:
pass
# pylint: enable=undefined-variable

View File

@ -21,6 +21,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"tf_attrtype.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
@ -39,14 +40,19 @@ filegroup(
"python_api.h",
"*test*",
],
),
) + [
"//tensorflow/cc:srcs",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
hdrs = [
"c_api.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow:internal",
"//tensorflow/c:__subpackages__",
@ -56,6 +62,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -66,14 +73,24 @@ tf_cuda_library(
}),
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
"tf_attrtype.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_no_xla",
":c_api_internal",
":tf_attrtype",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -89,16 +106,18 @@ tf_cuda_library(
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
hdrs = ["c_api.h"],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [":c_api_internal"] + select({
deps = [
":c_api_internal",
":tf_attrtype",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
@ -140,19 +159,11 @@ tf_cuda_library(
"//tensorflow/core:lib_platform",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "c_api_headers",
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
)
exports_files(
[
"version_script.lds",
@ -238,6 +249,28 @@ tf_cuda_library(
}),
)
tf_cuda_library(
name = "ops",
srcs = [
"ops.cc",
],
hdrs = [
"ops.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_status_helper",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}) + [":c_api_internal"],
)
# -----------------------------------------------------------------------------
# Tests
@ -286,7 +319,6 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
tags = [
"no_oss", # http://b/119522529
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since
@ -440,6 +472,27 @@ tf_cuda_cc_test(
],
)
tf_cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
linkopts = select({
"//conditions:default": [],
}),
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
# -----------------------------------------------------------------------------
# Python API target

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/kernels/logging_ops.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -368,7 +368,7 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len,
size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
size_t* dst_len, TF_Status* status) {
status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
if (!status->status.ok()) return 0;
if (TF_GetCode(status) != TF_OK) return 0;
return static_cast<size_t>(*dst - src) + *dst_len;
}
@ -423,7 +423,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
return new TF_DeprecatedSession({session});
} else {
DCHECK_EQ(nullptr, session);
@ -615,7 +615,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
offsets++;
const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
status->status = InvalidArgument(
"invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", status->status.error_message());
@ -775,7 +775,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
// TODO(nolivia): check this on a subset of the graph instead of all of
// it.
status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
session->graph->mu.unlock();
return false;
}
@ -795,7 +795,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
*graph_def.mutable_library() = graph.flib_def().ToProto();
session->graph->mu.unlock();
status->status = session->session->Extend(graph_def);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
// Contract is we always delete input_values[i].
return false;
}
@ -825,7 +825,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
if (!status->status.ok()) return false;
if (TF_GetCode(status) != TF_OK) return false;
}
return true;
}
@ -863,7 +863,7 @@ static void TF_Run_Helper(
// Serialize back to upstream client, who now owns the new buffer
if (run_metadata != nullptr) {
status->status = MessageToBuffer(run_metadata_proto, run_metadata);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
}
} else {
// NOTE(zongheng): PRun does not support RunOptions yet.
@ -883,7 +883,7 @@ static void TF_Run_Helper(
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
}
}
@ -940,7 +940,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
string new_handle;
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -979,7 +979,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
status->status = tensorflow::LoadLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
delete lib_handle;
return nullptr;
}
@ -1009,7 +1009,7 @@ TF_Buffer* TF_GetAllOpList() {
// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API
void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
@ -1407,7 +1407,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* value, TF_Status* status) {
Tensor t;
status->status = TF_TensorToTensor(value, &t);
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
@ -1417,13 +1417,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
std::vector<Tensor> t;
t.reserve(num_values);
for (int i = 0; i < num_values && status->status.ok(); ++i) {
for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
Tensor v;
status->status = TF_TensorToTensor(values[i], &v);
t.emplace_back(v);
}
if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
}
void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
@ -1471,11 +1471,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
}
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
// Run shape inference function for newly added node.
status->status = desc->graph->refiner.AddNode(ret);
}
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
// Add the node to the name-to-node mapping.
desc->graph->name_map[ret->name()] = ret;
} else if (ret != nullptr) {
@ -1524,10 +1524,10 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
if (!status->status.ok()) return -1;
if (TF_GetCode(status) != TF_OK) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
status->status = InvalidArgument("Output arg '", arg_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
@ -1546,7 +1546,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
if (!status->status.ok()) return -1;
if (TF_GetCode(status) != TF_OK) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
status->status = InvalidArgument("Input arg '", arg_name, "' not found");
@ -1644,7 +1644,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
TF_Status* status) {
TF_AttrMetadata metadata;
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return metadata;
if (TF_GetCode(status) != TF_OK) return metadata;
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
@ -1751,7 +1751,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
void* value, size_t max_length,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
if (attr->value_case() != tensorflow::AttrValue::kS) {
status->status =
InvalidArgument("Attribute '", attr_name, "' is not a string");
@ -1769,7 +1769,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
int max_values, void* storage,
size_t storage_size, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1802,7 +1802,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (!status->status.ok()) return; \
if (TF_GetCode(status) != TF_OK) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
@ -1824,7 +1824,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
PartialTensorShape shape;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
auto len = std::min(shape.dims(), num_dims);
for (int i = 0; i < len; ++i) {
value[i] = shape.dim_size(i);
@ -1832,21 +1832,21 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
}
void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
int64_t** values, int* num_dims,
int max_values, int64_t* storage,
int storage_size, TF_Status* status) {
int64_t** dims, int* num_dims, int num_shapes,
int64_t* storage, int storage_size,
TF_Status* status) {
std::vector<PartialTensorShape> shapes;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
if (!status->status.ok()) return;
auto len = std::min(static_cast<int>(shapes.size()), max_values);
if (TF_GetCode(status) != TF_OK) return;
auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
int64_t* p = storage;
int storage_left = storage_size;
for (int i = 0; i < len; ++i) {
// shapes[i].dims() == -1 for shapes with an unknown rank.
int64_t n = shapes[i].dims();
num_dims[i] = n;
values[i] = p;
dims[i] = p;
if (n < 0) {
continue;
}
@ -1866,7 +1866,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
const char* attr_name,
TF_Buffer* value, TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
if (attr->value_case() != tensorflow::AttrValue::kShape) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a shape.");
@ -1880,7 +1880,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
TF_Buffer** values, int max_values,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
if (attr->value_case() != tensorflow::AttrValue::kList) {
status->status =
InvalidArgument("Value for '", attr_name, "' is not a list");
@ -1890,7 +1890,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
for (int i = 0; i < len; ++i) {
values[i] = TF_NewBuffer();
status->status = MessageToBuffer(attr->list().shape(i), values[i]);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
// Delete everything allocated to far, the operation has failed.
for (int j = 0; j <= i; ++j) {
TF_DeleteBuffer(values[j]);
@ -1905,7 +1905,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
*value = nullptr;
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
*value = TF_TensorFromTensor(t, status);
}
@ -1914,7 +1914,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
TF_Status* status) {
std::vector<Tensor> ts;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {
values[i] = TF_TensorFromTensor(ts[i], status);
@ -1925,7 +1925,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto* attr = GetAttrValue(oper, attr_name, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
status->status = MessageToBuffer(*attr, output_attr_value);
}
@ -1941,7 +1941,10 @@ TF_Graph::TF_Graph()
refiner(graph.versions().producer(), graph.op_registry()),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
parent_inputs(nullptr) {
// Tell the shape refiner to also run shape inference on functions.
refiner.set_function_library_for_shape_inference(&graph.flib_def());
}
TF_Graph* TF_NewGraph() { return new TF_Graph; }
@ -2003,7 +2006,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
{
mutex_lock l(graph->mu);
status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
}
status->status = MessageToBuffer(*op_def, output_op_def);
}
@ -2121,7 +2124,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
// Add new nodes to name_map
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
@ -2175,7 +2178,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
auto results = new TF_ImportGraphDefResults();
mutex_lock l(graph->mu);
GraphImportGraphDefLocked(graph, def, options, results, status);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
delete results;
return nullptr;
}
@ -2233,7 +2236,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
// TODO(skyewm): set placeholder shape
TF_Operation* oper = TF_FinishOperation(desc, status);
if (!status->status.ok()) return false;
if (TF_GetCode(status) != TF_OK) return false;
*input = {oper, 0};
return true;
}
@ -2378,7 +2381,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
body_graph, body_inputs, body_outputs, name};
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
FreeWhileResources(&params);
return EmptyWhileParams();
}
@ -2582,7 +2585,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
@ -2630,7 +2633,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
status->status =
tensorflow::LoadSavedModel(session_options->options, run_options_proto,
export_dir, tag_set, &bundle);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
// Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
// extends using GraphDefs. The Graph instance is different, but equivalent
@ -2647,7 +2650,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
if (meta_graph_def != nullptr) {
status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TF_Session* session = new TF_Session(bundle.session.release(), graph);
@ -2747,7 +2750,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
string new_handle;
status->status = session->session->PRunSetup(input_names, output_names,
target_names, &new_handle);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
*handle = buf;
@ -2809,9 +2812,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
tensor, graph->refiner, *graph->graph.op_registry(),
graph->graph.versions().producer(), &evaluated, &result_tensor);
if (evaluated) {
DCHECK(status->status.ok());
DCHECK(TF_GetCode(status) == TF_OK);
*result = TF_TensorFromTensor(result_tensor, status);
if (!status->status.ok()) evaluated = false;
if (TF_GetCode(status) != TF_OK) evaluated = false;
}
return evaluated;
}
@ -2866,7 +2869,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2878,7 +2881,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2890,7 +2893,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
tensorflow::GetRegisteredKernelsForOp(name);
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(kernel_list, ret);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
TF_DeleteBuffer(ret);
return nullptr;
}
@ -2920,7 +2923,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len,
std::unique_ptr<tensorflow::ServerInterface> out_server;
status->status = tensorflow::NewServer(server_def, &out_server);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
return new TF_Server(std::move(out_server));
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/tf_attrtype.h"
// --------------------------------------------------------------------------
// C API for TensorFlow.
//
@ -686,19 +688,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs(
TF_Operation* oper, TF_Operation** control_outputs,
int max_control_outputs);
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
// TF_AttrMetadata describes the value of an attribute on an operation.
typedef struct TF_AttrMetadata {
// A boolean: 1 if the attribute value is a list, 0 otherwise.

File diff suppressed because it is too large Load Diff

View File

@ -62,6 +62,20 @@ extern "C" {
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
unsigned char enable);
// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the
// value of 'enabled'. Also returns the original value of that flag.
//
// Use in tests to allow XLA to fallback to TF classic. This has global effect.
TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation(
unsigned char enable);
// Sets XLA's auto jit mode according to the specified string, which is parsed
// as if passed in XLA_FLAGS. This has global effect.
TF_CAPI_EXPORT void TF_SetXLaAutoJitMode(const char* mode);
// Sets XLA's minimum cluster size. This has global effect.
TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size);
// Create a serialized tensorflow.ConfigProto proto, where:
//
// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
@ -93,26 +107,6 @@ TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func,
size_t* len);
// Creates a stack of data set + iterator nodes, currently hard-coded to return
// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success,
// returns the IteratorGetNext node, which caller can run or feed into an node.
//
// TODO(hongm): Extend the API to allow customization of the nodes created.
TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(
TF_Graph* graph, TF_Status* status);
// Similar to the above API, except that the returned iterator reads the
// file based dataset from `file_path`.
// If `is_mnist` is 0, the dataset corresponds to ImageNet.
// The iterators outputs 2 tensors:
// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or
// `batch_size` X 224 X 224 X 3 otherwise.
// - An int32 tensor of shape `batch_size`
// TODO(hongm): Extend the API to allow customization of the nodes created.
TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
TF_Graph* graph, const char* file_path, int batch_size,
unsigned char is_mnist, TF_Status* status);
// On success, dequeues a tensor from a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.

View File

@ -27,100 +27,6 @@ limitations under the License.
namespace tensorflow {
namespace {
void TestFakeIteratorStack() {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
CSession csession(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Run the graph.
const float base_value = 42.0;
for (int i = 0; i < 3; ++i) {
csession.SetOutputs({get_next});
csession.Run(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
ASSERT_EQ(TF_FLOAT, TF_TensorType(out));
ASSERT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(float), TF_TensorByteSize(out));
float* output_contents = static_cast<float*>(TF_TensorData(out));
ASSERT_EQ(base_value + i, *output_contents);
}
// This should error out since we've exhausted the iterator.
csession.Run(s);
ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s);
// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); }
TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
const string file_path = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record");
VLOG(1) << "data file path is " << file_path;
const int batch_size = 64;
TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets(
graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
CSession csession(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Run the graph.
// The two output tensors should look like:
// Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32)
// Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32)
for (int i = 0; i < 3; ++i) {
LOG(INFO) << "Running iter " << i;
csession.SetOutputs({{get_next, 0}, {get_next, 1}});
csession.Run(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
{
TF_Tensor* image = csession.output_tensor(0);
ASSERT_TRUE(image != nullptr);
ASSERT_EQ(TF_FLOAT, TF_TensorType(image));
// Confirm shape is 224 X 224 X 3
ASSERT_EQ(4, TF_NumDims(image));
ASSERT_EQ(batch_size, TF_Dim(image, 0));
ASSERT_EQ(224, TF_Dim(image, 1));
ASSERT_EQ(224, TF_Dim(image, 2));
ASSERT_EQ(3, TF_Dim(image, 3));
ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3,
TF_TensorByteSize(image));
}
{
TF_Tensor* label = csession.output_tensor(1);
ASSERT_TRUE(label != nullptr);
ASSERT_EQ(TF_INT32, TF_TensorType(label));
ASSERT_EQ(1, TF_NumDims(label));
ASSERT_EQ(batch_size, TF_Dim(label, 0));
ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label));
}
}
// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
TEST(CAPI_EXPERIMENTAL, GetServerDefTest) {
const string expected_text_proto(R"(cluster {
job {
@ -470,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
TFE_DeleteOp(identityn);
}
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", 2);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
constexpr size_t kNumInputs = 2;
for (size_t i = 0; i < kNumInputs; ++i) {
TFE_OpAddInput(concatv2, matrix, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
TFE_OpAddInput(concatv2, axis, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
AddEagerOpToGraphAndCheck(
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
int64_t attrN;
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_EQ(attrN, kNumInputs);
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
TFE_DeleteOp(concatv2);
}
TEST_F(AddEagerOpToGraphTest,
GeneratesInternalErrorsForInvalidNumberAttributes) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
int num_retvals = 5;
TFE_TensorHandle* retvals[5];
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", -1);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
&num_retvals, status_);
EXPECT_EQ(graph_op, nullptr);
EXPECT_EQ(status_->status.error_message(),
"Number attribute for length should be >=0!");
TFE_DeleteOp(concatv2);
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
}
} // namespace
} // namespace tensorflow

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_internal.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -295,7 +295,8 @@ Status FillFunctionBody(
}
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in tensorflow/python/framework/function.py.
// function graph_to_function_def(), which is located in
// tensorflow/python/framework/graph_to_function_def.py.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
bool append_hash_to_fn_name,
const std::vector<const Node*>& body_nodes,
@ -352,6 +353,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
argdef->set_type(node->output_type(idx));
const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
for (const auto& attr : node->attrs()) {
// Only copy internal attributes. These attributes will be applied to
// _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
// normal attributes for nodes cannot be applied to those _Arg/Placeholder
// nodes.
if (absl::StartsWith(attr.first, "_")) {
arg_attrs.mutable_attr()->insert(attr);
}
}
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}
@ -442,12 +453,21 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
} else {
signature_name = control_outputs[i]->name();
}
if (signature_name.empty()) {
return errors::InvalidArgument("Control output name must be not empty");
}
if (!control_output_names_set.insert(signature_name).second) {
return errors::InvalidArgument("Repeated control output name: ",
signature_name);
}
const string control_output_node =
node_names.Lookup(control_outputs[i]->name());
if (control_output_node.empty()) {
return errors::InvalidArgument(
"Control output node name must be not empty");
}
fdef->mutable_signature()->add_control_output(signature_name);
(*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name();
(*fdef->mutable_control_ret())[signature_name] = control_output_node;
}
return Status::OK();
@ -572,13 +592,13 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
std::unordered_map<const Node*, std::vector<int>> input_nodes;
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
&input_tensors, &input_nodes);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process outputs.
std::vector<tensorflow::OutputTensor> output_tensors;
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
outputs, &output_tensors);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process output names.
std::vector<string> output_names_vec;
@ -602,7 +622,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
std::vector<const Node*> body_nodes;
status->status = tensorflow::ComputeBodyNodes(
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
if (!status->status.ok()) return nullptr;
if (TF_GetCode(status) != TF_OK) return nullptr;
// Compute body nodes.
std::vector<const Node*> control_output_nodes;
@ -617,7 +637,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
input_tensors, output_tensors, output_names_vec, control_output_nodes,
control_output_names_vec, description, &tf_function->fdef);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
TF_DeleteFunction(tf_function);
return nullptr;
}

View File

@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
const char* attr_name, const char* attr_value,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
TF_NewGraph(), TF_DeleteGraph);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
TF_DeleteStatus);
TF_Operation* node;
NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
ASSERT_NE(func_, nullptr);
// Verify that FunctionDef ArgDef has attributes.
ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
auto arg_attrs = func_->fdef.arg_attr().find(0);
ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
auto iter = arg_attrs->second.attr().find("_test_attr");
ASSERT_NE(iter, arg_attrs->second.attr().end());
EXPECT_EQ(iter->second.s(), "value");
}
TEST_F(CApiFunctionTest, SetGradientAndRun) {
// Define the function and its grad
DefineFunction(func_name_, &func_);

View File

@ -24,8 +24,10 @@ limitations under the License.
#include <unordered_map>
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h" // NO_LINT
#include "tensorflow/core/platform/platform.h"
// clang-format on
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h"

View File

@ -29,8 +29,7 @@ namespace checkpoint {
class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
: reader_(nullptr),
v2_reader_(nullptr),
var_to_shape_map_(nullptr),
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status());
Set_TF_Status_from_Status(status, v2_reader_->status());
return;
}
auto result = BuildV2VarMaps();
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
} else {
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
Set_TF_Status_from_Status(status, reader_->status());
return;
}
var_to_shape_map_.reset(

View File

@ -39,7 +39,7 @@ class TensorSliceReader;
// variables.
class CheckpointReader {
public:
CheckpointReader(const string& filepattern, TF_Status* out_status);
CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const;
const string DebugString() const;

View File

@ -1,4 +1,5 @@
# Experimental extensions to the C API for eager execution of kernels.
licenses(["notice"]) # Apache 2.0
load(
@ -70,6 +71,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
@ -110,6 +112,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
@ -200,6 +203,7 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@ -236,7 +240,6 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
@ -256,3 +259,22 @@ filegroup(
srcs = ["c_api.h"],
visibility = ["//tensorflow:__subpackages__"],
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite)
filegroup(
name = "srcs",
srcs = glob(
[
"*.cc",
"*.h",
],
exclude = [
"c_api_experimental.cc",
"c_api_experimental.h",
"*test*",
],
),
visibility = ["//visibility:public"],
)

179
tensorflow/c/eager/c_api.cc Executable file → Normal file
View File

@ -21,11 +21,18 @@ limitations under the License.
#include <string>
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@ -38,11 +45,15 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@ -63,6 +74,17 @@ using tensorflow::int64;
using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(const tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
@ -77,6 +99,7 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
#if !defined(IS_MOBILE_PLATFORM)
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
@ -114,11 +137,12 @@ tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, int64 rendezvous_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
const string& remote_worker = remote_workers[i];
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
request.set_rendezvous_id(rendezvous_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -132,7 +156,9 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
tensorflow::eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(
remote_eager_workers->GetClient(remote_worker, &eager_client));
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
@ -198,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request make sure that we can create Rendevzous properly between
// Local and Remote context.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
@ -207,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, rendezvous_id, keep_alive_secs, server_def,
remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
remote_eager_workers.get(), ctx->context->Async(), base_request,
&remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, true));
session_name, server_def, base_request.cluster_device_attributes(),
true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@ -226,14 +271,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
auto* device_mgr = grpc_server->worker_env()->device_mgr;
ctx->context.InitializeRemote(std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts,
r, device_mgr, keep_alive_secs);
return tensorflow::Status::OK();
return ctx->context->InitializeRemote(
std::move(server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_contexts, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr.get());
#undef LOG_AND_RETURN_IF_ERROR
}
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
@ -330,7 +375,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char enable,
TF_Status* status) {
status->status = ctx->context.SetAsyncForThread(enable);
status->status = ctx->context->SetAsyncForThread(enable);
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@ -349,7 +394,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr.release(),
/*device_mgr_owned*/ true, r);
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
@ -359,23 +405,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, /*device_mgr_owned*/ false,
r);
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context.remote_device_mgr()) {
ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
}
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
@ -383,6 +430,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::errors::Unimplemented(
"TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
@ -391,11 +442,12 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
}
status->status =
UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
#endif // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
ctx->context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -405,19 +457,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context.GetDevicePlacementPolicy());
ctx->context->GetDevicePlacementPolicy());
}
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.AsyncWait();
status->status = ctx->context->AsyncWait();
}
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
status->status = ctx->context.GetStatus();
status->status = ctx->context->GetStatus();
}
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
ctx->context.ClearAsyncError();
ctx->context->ClearAsyncError();
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
@ -577,7 +629,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
return new TFE_Op(ctx, name, false, types,
new TFE_OpInferenceContext(op_def));
}
if (!ctx->context.FindFunctionByName(name)) {
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
@ -807,6 +859,54 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
funcs.get(), num_values));
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
@ -827,7 +927,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
const char* device_name,
TF_Status* status) {
tensorflow::TensorHandle* handle;
status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
@ -844,26 +944,31 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
status->status = ctx->context.AddFunctionDef(function_def);
status->status = ctx->context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = ctx->context.AddFunctionDef(function->fdef);
status->status = ctx->context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
status->status = ctx->context->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context.FindFunctionDef(name) != nullptr;
return ctx->context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true);
ctx->context.SetShouldStoreStepStats(true);
ctx->context->SetShouldStoreGraphs(true);
ctx->context->SetShouldStoreStepStats(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false);
ctx->context.SetShouldStoreStepStats(false);
ctx->context->SetShouldStoreGraphs(false);
ctx->context->SetShouldStoreStepStats(false);
}
} // extern "C"
@ -892,9 +997,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
ctx->context.ClearRunMetadata();
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata();
}
namespace {
@ -910,9 +1015,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
}
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -366,6 +366,18 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
const TFE_Op** value,
int num_values);
// Returns the length (number of tensors) of the input argument `input_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status);
// Returns the length (number of tensors) of the output argument `output_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in `retvals`.
//
@ -398,6 +410,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function,
TF_Status* status);
// Removes a function from the context. Once removed, you can no longer
// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
// other function which calls it as an attribute.
TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx,
const char* name,
TF_Status* status);
// Checks whether a function is registered under `name`.
TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx,
const char* name);

View File

@ -32,13 +32,13 @@ std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
TF_Status* status) {
std::vector<int64> shape;
int rank = TFE_TensorHandleNumDims(handle, status);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
shape.push_back(TFE_TensorHandleDim(handle, i, status));
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
return shape;
}
}
@ -53,7 +53,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* handle, TF_Status* status) {
const tensorflow::Tensor* tensor;
status->status = handle->handle->Tensor(&tensor);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
return nullptr;
}
@ -139,7 +139,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
if (!status->status.ok()) {
if (TF_GetCode(status) != TF_OK) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);

View File

@ -17,6 +17,12 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
@ -39,7 +45,7 @@ void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
TF_Buffer* buf, TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
@ -57,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() {
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
TFE_Context* eager_context) {
profiler_context->profiler_context.eager_context = &eager_context->context;
profiler_context->profiler_context.eager_context = eager_context->context;
}
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
@ -71,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(true);
ctx->context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context.SetShouldStoreGraphs(false);
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
@ -92,3 +98,423 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
num_tracing_attempts);
return s.ok();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
}
int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
return cell->cell.value();
}
TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringCounter0({name, description});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell()));
}
TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringCounter1({name, description, label1});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1)));
}
TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringCounter2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1, label2)));
}
void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
int64_t value) {
cell->cell.Set(value);
}
int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringIntGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
TFE_MonitoringIntGauge0* gauge) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
TFE_MonitoringIntGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringIntGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
const char* value) {
cell->cell.Set({value});
}
const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
tensorflow::string value = cell->cell.value();
void* data = tensorflow::port::Malloc(value.length());
value.copy(static_cast<char*>(data), value.length(), 0);
buf->data = data;
buf->length = value.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* status, const char* description) {
auto* result = new TFE_MonitoringStringGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
TFE_MonitoringStringGauge0* gauge) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* status, const char* description,
const char* label1) {
auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
TFE_MonitoringStringGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2) {
auto* result =
new TFE_MonitoringStringGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
bool value) {
cell->cell.Set(value);
}
bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringBoolGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
TFE_MonitoringBoolGauge0* gauge) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
TFE_MonitoringBoolGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringBoolGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
double value) {
cell->cell.Add(value);
}
void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
TF_Buffer* buf) {
string content;
cell->cell.value().SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
double growth_factor,
int bucket_count) {
return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
bucket_count);
});
}
void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
delete buckets;
}
TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringSampler0(
{name, buckets->create_buckets(), description});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell()));
}
TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1) {
auto* result = new TFE_MonitoringSampler1(
{name, buckets->create_buckets(), description, label1});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1)));
}
TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1, const char* label2) {
auto* result = new TFE_MonitoringSampler2(
{name, buckets->create_buckets(), description, label1, label2});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}

View File

@ -87,6 +87,229 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
// These APIs de-templated monitoring Counter for swig.
typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
// Atomically increments the value of the cell. The value must be non-negative.
TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
TFE_MonitoringCounterCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
TFE_MonitoringCounterCell* cell);
// APIs for Counter without label.
typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
// Returns a new Counter metric object. The caller should manage lifetime of
// the object. Using duplicate metric name will crash the program with fatal
// error.
TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
const char* name, TF_Status* status, const char* description);
// Deletes the Counter object.
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
TFE_MonitoringCounter0* counter);
// Retrieves the cell from the Counter object. The Counter object will manage
// lifetime of the cell.
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter);
// APIs for Counter with 1 label.
typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
const char* name, TF_Status* status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
TFE_MonitoringCounter1* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1);
// APIs for Counter with 2 labels.
typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
TFE_MonitoringCounter2* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Gauge APIs.
// These APIs de-templated monitoring Gauge for swig.
typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
// Atomically set the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
TFE_MonitoringIntGaugeCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
TFE_MonitoringIntGaugeCell* cell);
// APIs for Int Gauge without label.
typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
TFE_MonitoringIntGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
// APIs for Int Gauge with 1 label.
typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
TFE_MonitoringIntGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
const char* label1);
// APIs for Int Gauge with 2 label.
typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
TFE_MonitoringIntGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
TFE_MonitoringStringGaugeCell* cell, const char* value);
// Retrieves the string value and saves it in buffer.
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
// APIs for String Gauge without label.
typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
TFE_MonitoringStringGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
// APIs for String Gauge with 1 label.
typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
TFE_MonitoringStringGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
const char* label1);
// APIs for String Gauge with 2 label.
typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
TFE_MonitoringStringGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
TFE_MonitoringBoolGaugeCell* cell, bool value);
TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
TFE_MonitoringBoolGaugeCell* cell);
// APIs for Bool Gauge without label.
typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
TFE_MonitoringBoolGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
// APIs for Bool Gauge with 1 label.
typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
TFE_MonitoringBoolGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
const char* label1);
// APIs for Bool Gauge with 2 label.
typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
TFE_MonitoringBoolGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Sampler APIs.
// These APIs de-templated monitoring Sampler for swig.
typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
// Atomically add the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
TFE_MonitoringSamplerCell* cell, double value);
// Retrieves the current value of the cell. The return value is a HistogramProto
// saved in buffer.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
// APIs for sampler buckets
typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
int bucket_count);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
TFE_MonitoringBuckets* buckets);
// APIs for Sampler without label.
typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
TFE_MonitoringSampler0* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler);
// APIs for Sampler with 1 label.
typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
TFE_MonitoringSampler1* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1);
// APIs for Sampler with 2 label.
typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TFE_MonitoringSampler2* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -16,14 +16,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/profiler/trace_events.pb.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -79,11 +81,15 @@ void ExecuteWithProfiling(bool async) {
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
// TODO(fishx): move following check out from this if statement.
// This is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
}
EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
@ -125,5 +131,165 @@ TEST(CAPI, MultipleProfilerSession) {
TFE_DeleteProfilerContext(profiler_context);
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =
TFE_MonitoringNewCounter0("test/counter", status, "description");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell = TFE_MonitoringGetCellCounter0(counter);
TFE_MonitoringCounterCellIncrementBy(cell, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/counter",
metrics->point_set_map.at("test/counter")->metric_name);
EXPECT_EQ(
1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringCounterCellIncrementBy(cell, 5);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(
6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringDeleteCounter0(counter);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(metrics->point_set_map.end(),
metrics->point_set_map.find("test/counter"));
}
TEST(CAPI, MonitoringCounterMultiple) {
TF_Status* status = TF_NewStatus();
auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
"description", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test");
TFE_MonitoringCounterCellIncrementBy(cell1, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1);
auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
"description", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
TFE_MonitoringCounterCellIncrementBy(cell2, 2);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2);
TFE_MonitoringDeleteCounter1(counter1);
TFE_MonitoringDeleteCounter2(counter2);
}
TEST(CAPI, MonitoringGauge0) {
TF_Status* status = TF_NewStatus();
auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
TFE_MonitoringIntGaugeCellSet(cell, 1);
EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringIntGaugeCellSet(cell, 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleGauge) {
TF_Status* status = TF_NewStatus();
auto* gauge1 =
TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
TFE_MonitoringBoolGaugeCellSet(cell1, true);
EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
"label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
TFE_MonitoringStringGaugeCellSet(cell2, "str");
auto* buf = new TF_Buffer;
TFE_MonitoringStringGaugeCellValue(cell2, buf);
string data(static_cast<const char*>(buf->data), buf->length);
delete buf;
EXPECT_EQ(data, "str");
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringSampler0) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler =
TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellSampler0(sampler);
TFE_MonitoringSamplerCellAdd(cell, 1.0);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/sampler",
metrics->point_set_map.at("test/sampler")->metric_name);
EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringSamplerCellAdd(cell, 5.0);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleSampler) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
"test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
TFE_MonitoringSamplerCellAdd(cell1, 1.0);
TFE_MonitoringSamplerCellAdd(cell1, 2.0);
TF_Buffer* result1 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell1, result1);
tensorflow::HistogramProto hitogram1;
EXPECT_TRUE(hitogram1.ParseFromString(
{reinterpret_cast<const char*>(result1->data), result1->length}));
EXPECT_EQ(hitogram1.sum(), 3.0);
delete result1;
auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
"test", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
TFE_MonitoringSamplerCellAdd(cell2, 2.0);
TFE_MonitoringSamplerCellAdd(cell2, 3.0);
TF_Buffer* result2 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell2, result2);
tensorflow::HistogramProto hitogram2;
EXPECT_TRUE(hitogram2.ParseFromString(
{reinterpret_cast<const char*>(result2->data), result2->length}));
EXPECT_EQ(hitogram2.sum(), 5.0);
delete result2;
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
} // namespace
} // namespace tensorflow

View File

@ -15,8 +15,6 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include "tensorflow/c/eager/c_api.h"
#include <algorithm>
#include <cstddef>
#include <map>
@ -28,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -37,19 +36,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
@ -66,13 +60,18 @@ struct TFE_Context {
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy, bool async,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::Rendezvous* rendezvous,
const tensorflow::CustomKernelCreator* custom_kernel_creator)
: context(new tensorflow::EagerContext(
opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, device_mgr, device_mgr_owned, rendezvous,
custom_kernel_creator)) {}
tensorflow::EagerContext context;
~TFE_Context() { context->Unref(); }
tensorflow::EagerContext* context;
};
struct TFE_TensorHandle {
@ -112,7 +111,7 @@ struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
TFE_OpInferenceContext* inference_ctx)
: operation(&ctx->context, op, is_function, t),
: operation(ctx->context, op, is_function, t),
inference_ctx(inference_ctx) {}
tensorflow::EagerOperation operation;
@ -131,6 +130,124 @@ struct TFE_Profiler {
std::unique_ptr<tensorflow::ProfilerSession> profiler;
};
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringIntGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
};
struct TFE_MonitoringStringGaugeCell {
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
};
struct TFE_MonitoringBoolGaugeCell {
tensorflow::monitoring::GaugeCell<bool> cell;
};
template <typename ValueType, int NumLabels>
struct TFE_MonitoringGauge {
template <typename... LabelDesc>
TFE_MonitoringGauge(const char* name, const char* description,
LabelDesc&&... label) {
gauge = absl::WrapUnique(
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
};
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
using TFE_MonitoringGauge::TFE_MonitoringGauge;
};
struct TFE_MonitoringBuckets {
TFE_MonitoringBuckets(
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
fn) {
create_buckets = fn;
}
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
create_buckets;
};
struct TFE_MonitoringSamplerCell {
tensorflow::monitoring::SamplerCell cell;
};
template <int NumLabels>
struct TFE_MonitoringSampler {
template <typename... LabelDesc>
TFE_MonitoringSampler(
const char* name,
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
const char* description, LabelDesc&&... label) {
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
{name, description, label...}, std::move(buckets)));
}
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
};
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
using TFE_MonitoringSampler::TFE_MonitoringSampler;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include <string.h>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContext(ctx);
// Delete tensors after context is deleted.
TFE_DeleteTensorHandle(h0_task1);
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
TestRemoteExecuteDeleteTensorAfterContext(false);
}
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
TestRemoteExecuteDeleteTensorAfterContext(true);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -1225,6 +1281,8 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@ -1295,6 +1353,8 @@ TEST(CAPI, Function_ident_XLA_CPU) {
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
@ -1371,6 +1431,8 @@ void FunctionDefAndExecute(bool async) {
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@ -1412,6 +1474,8 @@ void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval[0]);
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
@ -1781,4 +1845,80 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_DeleteTensorHandle(dim);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths before building the attributes (should fail)
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths before executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[2] = {nullptr};
int num_retvals = 2;
TFE_Execute(identityOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Try to retrieve lengths after executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteOp(identityOp);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(retvals[1]);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status));
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status));
CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteOp(identityOp);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -47,11 +47,12 @@ struct OpTapeEntry {
// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = gtl::FlatMap<int64, int64>;
using TensorTape = std::unordered_map<int64, int64>;
// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
using OpTape =
std::unordered_map<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@ -94,6 +95,7 @@ class VSpace {
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0;
@ -143,7 +145,7 @@ class GradientTape {
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result);
@ -156,7 +158,7 @@ class GradientTape {
// Map from tensor id to number of remaining usages (i.e. how many entries in
// the tape refer to it); to aid in tape garbage collection.
gtl::FlatMap<int64, int64> tensor_usage_;
std::unordered_map<int64, int64> tensor_usage_;
// If false, all activations are deleted in the first call to ComputeGradient.
// Else, only when this is destructed.
@ -307,11 +309,11 @@ struct BackpropInitialState {
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
gtl::FlatMap<int64, int64> tensor_usage_counts;
std::unordered_map<int64, int64> tensor_usage_counts;
// Maps from op ID to how many output tensors of this op still need to have
// their gradients computed.
gtl::FlatMap<int64, int64> op_missing_tensor;
std::unordered_map<int64, int64> op_missing_tensor;
};
// If `persistent_tape` is true, op_tape is not changed and none of the
@ -323,7 +325,7 @@ template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
OpTape<BackwardFunction, TapeTensor>* op_tape,
const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
const std::unordered_set<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
@ -383,7 +385,7 @@ BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
const std::unordered_map<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
@ -397,10 +399,10 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction, TapeTensor>& op_tape,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
std::unordered_map<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@ -454,12 +456,14 @@ Status InitialGradients(
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
{"SoftmaxCrossEntropyWithLogits", {1}},
{"SparseSoftmaxCrossEntropyWithLogits", {1}},
{"FusedBatchNorm", {1, 2, 3, 4}},
});
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
static auto* const m =
new std::unordered_map<string, std::unordered_set<int>>({
{"SoftmaxCrossEntropyWithLogits", {1}},
{"SparseSoftmaxCrossEntropyWithLogits", {1}},
{"FusedBatchNorm", {1, 2, 3, 4}},
});
return m;
}
@ -476,16 +480,16 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
std::unordered_map<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids,
sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients);
@ -501,7 +505,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
cleanup();
return s;
}
gtl::FlatMap<int64, int64> gradients_size;
std::unordered_map<int64, int64> gradients_size;
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
// time, for better CPU backprop performance.
VLOG(1) << "Initial stack:";
@ -524,7 +529,17 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
state.op_tape.erase(op_it);
std::vector<Gradient*> out_gradients;
out_gradients.reserve(trace.output_tensor_info.size());
std::vector<int64> unneeded_gradients;
for (int i = 0; i < trace.input_tensor_id.size(); i++) {
const auto& in_tensor_id = trace.input_tensor_id[i];
if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
sources_set.find(in_tensor_id) == sources_set.end()) {
unneeded_gradients.push_back(i);
}
}
bool any_gradient_nonzero = false;
std::vector<int> zero_indices;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
@ -535,7 +550,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
out_gradients.push_back(nullptr);
zero_indices.push_back(i);
}
} else {
any_gradient_nonzero = true;
@ -557,8 +573,13 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
std::vector<Gradient*> in_gradients;
if (any_gradient_nonzero) {
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
for (const auto i : zero_indices) {
out_gradients[i] = vspace.Zeros(trace.output_tensor_info[i]);
}
Status s;
s = vspace.CallBackwardFunction(trace.backward_function,
unneeded_gradients, out_gradients,
&in_gradients);
if (!persistent_) {
trace.backward_function_deleter(trace.backward_function);
}
@ -634,14 +655,16 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
VLOG(1) << "Op " << op_id << " missing " << missing_it->second
<< " output gradients";
if (missing_it->second == 0) {
op_stack.push_back(op_id);
op_stack.insert(op_stack.begin(), op_id);
}
}
}
}
CHECK(state.op_tape.empty());
if (!state.op_tape.empty()) {
return tensorflow::errors::Internal("Invalid tape state.");
}
result->reserve(source_tensor_ids.size());
gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
for (auto is : source_tensor_ids) {
auto grad_it = gradients.find(is);
if (grad_it == gradients.end()) {

View File

@ -0,0 +1,122 @@
# Description:
# Experimental C APIs for TensorFlow.
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
tf_cuda_library(
name = "rendezvous_internal",
srcs = [
"rendezvous.cc",
],
hdrs = [
"rendezvous.h",
"rendezvous_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "rendezvous",
hdrs = [
"rendezvous.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api",
],
)
tf_cuda_library(
name = "network_internal",
srcs = [
"network.cc",
],
hdrs = [
"network.h",
"network_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "network",
hdrs = [
"network.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":network_internal",
":rendezvous",
"//tensorflow/c:c_api",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cuda_cc_test(
name = "network_test",
size = "medium",
srcs = ["network_test.cc"],
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":network",
":network_internal",
":rendezvous",
":rendezvous_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:env",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

View File

@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
using tensorflow::ServerFactory;
namespace tensorflow {
/* static */ Status CGrpcServer::Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server) {
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
join_function, delete_function);
GrpcServerOptions options;
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
return new CRendezvousMgr(env, rendezvous_builder);
};
TF_RETURN_IF_ERROR(grpc_server->Init(options));
TF_Status* tf_status = TF_NewStatus();
grpc_server->SetContext(init_function(
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
TF_RETURN_IF_ERROR(tf_status->status);
TF_DeleteStatus(tf_status);
out_server->reset(grpc_server);
return Status::OK();
}
Status CGrpcServer::Start() {
Status status = GrpcServer::Start();
TF_Status* tf_status = TF_NewStatus();
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Stop() {
Status status = GrpcServer::Stop();
TF_Status* tf_status = TF_NewStatus();
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Join() {
Status status = GrpcServer::Join();
TF_Status* tf_status = TF_NewStatus();
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
namespace {
// Factory that creates CGrpcServer instances.
class CServerFactory : public ServerFactory {
public:
CServerFactory(bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*,
TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder)
: accept_function_(accept_function),
init_function_(init_function),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,
join_function_, delete_function_, rendezvous_builder_, out_server));
return Status::OK();
}
// Returns true if and only if this factory can create a server
// based on the given `server_def`.
bool AcceptsOptions(const ServerDef& server_def) override {
return (*accept_function_)(server_def.protocol().c_str());
}
private:
bool (*accept_function_)(const char* protocol);
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
TF_RemoteRendezvousBuilder* rendezvous_builder_;
};
} // namespace
} // namespace tensorflow
// Server factory representation to use in C API.
// Holds CServerFactory pointer.
struct TF_GrpcServerFactory {
::tensorflow::CServerFactory* factory;
};
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder) {
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
server_factory->factory = new ::tensorflow::CServerFactory(
accept_function, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
return server_factory;
}
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
DCHECK_NE(server_factory, nullptr);
delete server_factory;
}
void TF_RegisterGrpcServerFactory(const char* server_type,
TF_GrpcServerFactory* server_factory) {
ServerFactory::Register(server_type, server_factory->factory);
}

View File

@ -0,0 +1,97 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for TensorFlow Networking.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Users wishing to register a custom GrpcServer should call
// TF_NewServerFactory and then TF_RegisterGrpcServerFactory.
//
// Example:
// ```c++
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
// rendezvous_init_function,
// receive_from_remote_async_function,
// rendezvous_delete_function);
//
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
// accept_function,
// init_function,
// start_function,
// stop_function,
// join_function,
// delete_function,
// rendezvous_builder);
// TF_RegisterGrpcServerFactory("customfactory", factory);
// ...
// TF_DeleteGrpcServerFactory(factory);
// ```
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
typedef struct TF_GrpcServer TF_GrpcServer;
typedef struct TF_ServerContext {
TF_GrpcServer* const server;
void* context;
} TF_ServerContext;
// Creates a new TF_GrpcServerFactory instance. Caller takes ownership
// of TF_GrpcServerFactory instance and should deallocate it by calling
// TF_GrpcDeleteServerFactory.
// accept_function should return true if this ServerFactory can create
// server instances for the given protocol name (for e.g. grpc+verbs).
// GRPC servers created by this factory will call provided
// init_function, start_function, stop_function, join_function and
// delete_function.
//
// Note that clean shutdown is currently not implemented for GrpcServer.
// So, stop_function will never be called now but may be in the future
// when stop mechanism is supported.
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Deletes TF_GrpcServerFactory instances.
// Note that this function only deletes TF_GrpcServerFactory wrapper.
// Actual underlying server factory would not be deleted and will
// remain registered.
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
TF_GrpcServerFactory* server_factory);
// Registers provided server_factory for the given server_type.
// server_type must be unique to the server factory.
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
const char* server_type, TF_GrpcServerFactory* server_factory);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
// GrpcServer implementation that forwards calls to callbacks.
class CGrpcServer : public GrpcServer {
protected:
CGrpcServer(const ServerDef& server_def,
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*))
: GrpcServer(server_def, ::tensorflow::Env::Default()),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
context_(nullptr) {}
public:
static Status Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server);
Status Start() override;
Status Stop() override;
Status Join() override;
~CGrpcServer() override { delete_function_(context_); }
protected:
void SetContext(void* context) { context_ = context; }
private:
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
void* context_;
friend class NetworksTest;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_

View File

@ -0,0 +1,256 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
bool accept_functionA(const char* protocol_name) {
return strcmp(protocol_name, "grpc+A") == 0;
}
bool accept_functionB(const char* protocol_name) {
return strcmp(protocol_name, "grpc+B") == 0;
}
struct SomeServerData {
bool server_started = false;
};
struct SomeRendezvousData {
int test = 0;
};
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
SomeServerData* server_data = new SomeServerData();
TF_SetStatus(status, TF_OK, "");
return server_data;
}
void start_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
auto* server_data = static_cast<SomeServerData*>(context);
server_data->server_started = true;
TF_SetStatus(status, TF_OK, "");
}
void stop_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void join_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void delete_function(void* context) {
auto* server_data = static_cast<SomeServerData*>(context);
delete server_data;
}
void* rendezvous_init_function(void* server_context) {
return new SomeRendezvousData();
}
void Deallocator(void* data, size_t, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
*reinterpret_cast<bool*>(arg) = true;
}
void receive_from_remote_async_function(TF_ParsedKey* key,
TF_RendezvousArgs* args,
TF_RendezvousDoneCallback* callback,
void* context) {
// Create dummy tensor
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
EIGEN_MAX_ALIGN_BYTES, num_bytes));
int64_t dims[] = {2, 3};
bool deallocator_called = false;
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
&Deallocator, &deallocator_called);
callback->tensor = tensor;
auto* tf_status = TF_NewStatus();
TF_SetStatus(tf_status, TF_OK, "");
callback->status = tf_status;
TF_RendezvousDone(callback);
TF_DeleteStatus(tf_status);
TF_DeleteTensor(tensor);
}
void rendezvous_delete_function(void* context) {
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
delete rendezvous_data;
}
tensorflow::ServerDef GetServerDef(const string& protocol,
const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol(protocol);
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
class NetworksTest : public ::testing::Test {
public:
~NetworksTest() override {}
SomeServerData* GetServerData(CGrpcServer* server) {
EXPECT_NE(server->context_, nullptr);
return static_cast<SomeServerData*>(server->context_);
}
};
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
Rendezvous::ParsedKey result;
CHECK(
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
name, FrameAndIter(0, 0)),
&result)
.ok());
return result;
}
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
RemoteRendezvous* remote_rendezvous) {
int rendezvous_id = 0;
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, *server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
}
TEST_F(NetworksTest, TestStartServer) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionA, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryA", factory);
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
auto* server_data = GetServerData(grpc_server);
ASSERT_FALSE(server_data->server_started);
TF_EXPECT_OK(server->Start());
ASSERT_TRUE(server_data->server_started);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// TODO(annarev): find a clean way to shutdown server.
server.release();
}
TEST_F(NetworksTest, TestReceiveData) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionB, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryB", factory);
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
TF_EXPECT_OK(server->Start());
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
auto* remote_rendezvous = rendezvous_mgr->Find(0);
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
Rendezvous::Args args;
bool done_callback_called = false;
auto* done_callback_called_ptr = &done_callback_called;
absl::Notification notification;
auto* notification_ptr = &notification;
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
remote_rendezvous->RecvAsync(
key, args,
[done_callback_called_ptr, notification_ptr](
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
const Tensor&, const bool) mutable {
*done_callback_called_ptr = true;
notification_ptr->Notify();
});
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
ASSERT_EQ(done_callback_called, true);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// Server doesn't have a clean shutdown.
server.release();
}
} // namespace tensorflow

View File

@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/rendezvous.h"
#include <functional>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context),
void* server_context)
: BaseRemoteRendezvous(env, step_id),
receive_from_remote_async_function_(receive_from_remote_async_function),
delete_function_(delete_function),
context_(nullptr) {}
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) {
TF_ParsedKey key;
key.src_device = parsed.src_device.data();
key.src_device_len = parsed.src_device.size();
key.dst_device = parsed.dst_device.data();
key.dst_device_len = parsed.dst_device.size();
key.full_key = parsed.FullKey().data();
key.full_key_len = parsed.FullKey().size();
TF_DeviceContext* device_context = new TF_DeviceContext();
device_context->context = args.device_context;
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
alloc_attrs->value = args.alloc_attrs.value;
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
alloc_attrs->on_host = args.alloc_attrs.on_host();
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
cargs->device_context = device_context;
cargs->alloc_attrs = alloc_attrs;
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
done_callback->done_callback = done;
done_callback->recv_args = cargs;
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
}
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
} // namespace tensorflow
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context)) {
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
builder->init_function = init_function;
builder->delete_function = delete_function;
builder->receive_from_remote_async_function =
receive_from_remote_async_function;
return builder;
}
void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder) {
DCHECK_NE(rendezvous_builder, nullptr);
delete rendezvous_builder;
}
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback) {
DCHECK_NE(callback, nullptr);
::tensorflow::Tensor tensor;
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
::tensorflow::Rendezvous::Args recv_args;
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
recv_args.device_context = callback->recv_args->device_context->context;
::tensorflow::Rendezvous::Args sent_args;
callback->done_callback(callback->status->status, sent_args, recv_args,
tensor, callback->dead);
if (callback->recv_args) {
DCHECK_NE(callback->recv_args, nullptr);
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
DCHECK_NE(callback->recv_args->device_context, nullptr);
delete callback->recv_args->alloc_attrs;
delete callback->recv_args->device_context;
delete callback->recv_args;
}
delete callback;
callback = nullptr;
}

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for Rendezvous.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Custom rendezvous allows for custom implementations of Recv call.
//
// Users wishing to create custom rendezvous objects should call
// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder
// to to TF_NewServerFactory.
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
typedef struct TF_ParsedKey TF_ParsedKey;
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
// Creates a new TF_RemoteRendezvousBuilder instance.
// Rendezvous instances will forward calls to init_function,
// receive_from_remote_async_function and delete_function passed here.
//
// Note that receive_from_remote_async_function implementation must call
// TF_Done with the TF_DoneCallback passed as an argument.
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context));
// Deletes TF_RemoteRendezvousBuilder instances.
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Calls TF_DoneCallback and destroys callback instance and
// TF_DoneCallback members except `tensor` and `status`. Caller is
// responsible for deleting `tensor` and `status` after TF_Done returns.
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#include <stddef.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/macros.h"
struct TF_ParsedKey {
// char* members might not be null-terminated.
const char* src_device;
size_t src_device_len;
const char* dst_device;
size_t dst_device_len;
const char* full_key;
size_t full_key_len;
};
struct TF_AllocatorAttributes {
bool on_host;
bool nic_compatible;
// NOTE: The upper 8 bits of the value are reserved for
// device-specific uses. Implementors of a device can interpret these
// upper 8 bits in device-specific ways, and ops implemented for those
// devices are responsible for setting those 8 bits appropriately.
tensorflow::uint32 value = 0;
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
// a named special-purpose allocator on the same device.
tensorflow::int32 scope_id = 0;
};
struct TF_DeviceContext {
::tensorflow::DeviceContext* context;
};
struct TF_RendezvousArgs {
const TF_DeviceContext* device_context;
const TF_AllocatorAttributes* alloc_attrs;
};
struct TF_RendezvousDoneCallback {
::tensorflow::Rendezvous::DoneCallback done_callback;
// TODO(annarev): figure out if we should also support sent_args.
const TF_RendezvousArgs* recv_args;
TF_Tensor* tensor = nullptr;
TF_Status* status;
bool dead;
};
struct TF_RemoteRendezvousBuilder {
void* (*init_function)(void* server_context);
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function)(void* context);
void* server_context;
};
namespace tensorflow {
class CRemoteRendezvous : public BaseRemoteRendezvous {
public:
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*, void* context),
void (*delete_function)(void* context),
void* server_context);
void SetContext(void* context) { context_ = context; }
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~CRemoteRendezvous() override;
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function_)(void* context);
void* context_;
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
};
class CRendezvousMgr : public BaseRendezvousMgr {
public:
CRendezvousMgr(const WorkerEnv* env,
const TF_RemoteRendezvousBuilder* rendezvous_builder)
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override {
auto* rendezvous = new CRemoteRendezvous(
worker_env, step_id,
rendezvous_builder_->receive_from_remote_async_function,
rendezvous_builder_->delete_function,
rendezvous_builder_->server_context);
rendezvous->SetContext(rendezvous_builder_->init_function(
rendezvous_builder_->server_context));
return rendezvous;
}
private:
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_

View File

@ -0,0 +1,44 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_kernel_library",
)
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
tf_kernel_library(
name = "bitcast_op",
prefix = "bitcast_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
deps = [
":bitcast_op",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Changes to the Android srcs here should be replicated in
# tensorflow/contrib/makefile/tf_op_files.txt
# LINT.IfChange
filegroup(
name = "android_all_ops",
srcs = [
"bitcast_op.cc",
],
)
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)

View File

@ -0,0 +1,171 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
// BitcastOp implements a bitcast kernel, creating an output tensor that shares
// the same data buffer as the input but with a different shape and/or data
// type. Its inputs are:
//
// * the input tensor
// * an attribute named "T" containing the TF_DataType of the input tensor
// * an attribute named "type" containing the TF_DataType of the output tensor
//
// Given an input tensor of shape [...], if the input DataType "T" is larger
// than the output DataType "type", then the shape changes from [...]
// to [..., sizeof(T)/sizeof(type)].
//
// If "T" is smaller than "type", the operator requires that the rightmost
// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from
// [..., sizeof(type)/sizeof(T)] to [...].
//
// Bitcast is implemented as a low-level cast, so machines with different endian
// orderings will give different results.
typedef struct BitcastOp {
TF_DataType input_data_type;
TF_DataType output_data_type;
size_t in_size;
size_t out_size;
} BitcastOp;
static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) {
auto* kernel = new BitcastOp;
TF_Status* s = TF_NewStatus();
TF_OpKernelConstruction_GetAttrType(ctx, "T", &kernel->input_data_type, s);
if (TF_GetCode(s) == TF_OK) {
TF_OpKernelConstruction_GetAttrType(ctx, "type", &kernel->output_data_type,
s);
}
if (TF_GetCode(s) == TF_OK) {
kernel->in_size = TF_DataTypeSize(kernel->input_data_type);
kernel->out_size = TF_DataTypeSize(kernel->output_data_type);
size_t check_size = std::max(kernel->in_size, kernel->out_size) %
std::min(kernel->in_size, kernel->out_size);
if (check_size != 0) {
std::ostringstream err;
err << "cannot convert between datatype " << kernel->input_data_type
<< " and " << kernel->output_data_type;
TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(s) != TF_OK) {
TF_OpKernelConstruction_Failure(ctx, s);
delete kernel;
kernel = nullptr;
}
TF_DeleteStatus(s);
return kernel;
}
static void BitcastOp_Delete(void* kernel) {
delete static_cast<BitcastOp*>(kernel);
}
static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
auto* k = static_cast<BitcastOp*>(kernel);
int dim_count = 0;
TF_Tensor* tensor;
TF_Status* status = TF_NewStatus();
TF_GetInput(ctx, 0, &tensor, status);
if (TF_GetCode(status) == TF_OK) {
dim_count = TF_NumDims(tensor);
if (!(k->in_size >= k->out_size ||
(dim_count > 0 &&
TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) {
std::ostringstream err;
err << "Cannot bitcast from " << k->input_data_type << " to "
<< k->output_data_type;
TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
}
}
if (TF_GetCode(status) == TF_OK) {
auto* dims = new int64_t[dim_count + 1];
int new_dim_count = dim_count;
for (int dim = 0; dim < dim_count; ++dim) {
dims[dim] = TF_Dim(tensor, dim);
}
if (k->out_size < k->in_size) {
dims[new_dim_count++] = static_cast<int64_t>(k->in_size / k->out_size);
} else if (k->out_size > k->in_size) {
--new_dim_count;
}
TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0,
TF_DataTypeSize(k->output_data_type));
TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims,
new_dim_count, status);
if (TF_GetCode(status) == TF_OK) {
TF_SetOutput(ctx, 0, output, status);
}
delete[] dims;
TF_DeleteTensor(output);
}
if (TF_GetCode(status) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status);
}
TF_DeleteStatus(status);
TF_DeleteTensor(tensor);
}
static void RegisterBitcastOp() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
&BitcastOp_Create, &BitcastOp_Compute,
&BitcastOp_Delete);
TF_RegisterKernelBuilder("BitcastOp", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering bitcast kernel";
}
#if GOOGLE_CUDA
{
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU,
&BitcastOp_Create, &BitcastOp_Compute,
&BitcastOp_Delete);
TF_RegisterKernelBuilder("BitcastOp", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering CUDA bitcast kernel";
}
#endif
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the bitcast kernel.
static bool BitcastOpIsRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) {
RegisterBitcastOp();
}
return true;
}();

View File

@ -0,0 +1,101 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
void TestBitcastOp(Tensor* input_tensor, DataType out_type,
TensorShape expected_shape, error::Code expected_code) {
Status status;
NodeDef def;
def.set_op("Bitcast");
def.set_device(DEVICE_CPU);
AttrValue typeAttr;
SetAttrValue(input_tensor->dtype(), &typeAttr);
AttrValue outTypeAttr;
SetAttrValue(out_type, &outTypeAttr);
(*def.mutable_attr())["T"] = typeAttr;
(*def.mutable_attr())["type"] = outTypeAttr;
def.add_input(
strings::StrCat("input1: ", DataTypeString(input_tensor->dtype())));
std::unique_ptr<OpKernel> kernel =
CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status);
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr, false);
params.device = &dummy_device;
params.op_kernel = kernel.get();
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.emplace_back(input_tensor);
params.inputs = &inputs;
OpKernelContext ctx(&params);
kernel->Compute(&ctx);
ASSERT_EQ(expected_code, ctx.status().code());
if (expected_code == error::OK) {
ASSERT_EQ(expected_shape, ctx.mutable_output(0)->shape())
<< ctx.mutable_output(0)->shape().DebugString();
}
}
TEST(BitcastOpTest, TestUpcast) {
Tensor int8_input(DT_UINT8, {8});
for (int i = 0; i < 8; i++) {
int8_input.vec<uint8>()(i) = static_cast<uint8>(1);
}
TestBitcastOp(&int8_input, DT_UINT64, TensorShape(), error::OK);
}
TEST(BitcastOpTest, TestDowncast) {
Tensor int64_input(static_cast<uint64>(1));
TestBitcastOp(&int64_input, DT_UINT8, TensorShape({8}), error::OK);
}
TEST(BitcastOpTest, TestCastToSameSize) {
Tensor int32_input(DT_UINT32, {4, 6});
TestBitcastOp(&int32_input, DT_UINT8, TensorShape({4, 6, 4}), error::OK);
}
TEST(BitcastOpTest, TestImpossibleCast) {
Tensor int8_input(DT_UINT8, {1});
TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT);
}
} // namespace
} // namespace tensorflow

326
tensorflow/c/ops.cc Normal file
View File

@ -0,0 +1,326 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"
using ::tensorflow::DataType;
using ::tensorflow::OpDef;
using ::tensorflow::OpDeprecation;
using ::tensorflow::OpShapeInferenceFn;
using ::tensorflow::Set_TF_Status_from_Status;
using ::tensorflow::Status;
using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
typedef struct TF_OpDefinitionBuilder {
// The op definition proto representing the op.
tensorflow::OpDef op_def;
// The shape inference function, or nullptr if none is provided for this op.
OpShapeInferenceFn shape_inference_func;
} TF_OpDefinitionBuilder;
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
auto* result = new TF_OpDefinitionBuilder;
result->op_def.set_name(op_name);
return result;
}
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
delete builder;
}
static void PopulateArg(OpDef::ArgDef* arg, const char* name,
TF_DataType type) {
arg->set_name(name);
arg->set_type(static_cast<DataType>(type));
}
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_input_arg(), name, type);
}
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_output_arg(), name, type);
}
#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
bool arg_name) { \
builder->op_def.builder_setter_name(arg_name); \
}
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput,
set_allows_uninitialized_input,
allows_unintialized_input)
static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder,
const char* name, const char* type_name) {
OpDef::AttrDef* attr = builder->op_def.add_attr();
attr->set_name(name);
attr->set_type(type_name);
return attr;
}
#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \
void TF_OpDefinitionBuilderAdd##attr_type##Attr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
AddAttribute(builder, name, type_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \
attr->mutable_default_value()->set_##field_name(field_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name[], size_t n) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \
for (int _i = 0; _i < n; ++_i) { \
attr->mutable_default_value()->mutable_list()->add_##field_name( \
field_name[_i]); \
} \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
builder, name, NULL, 0); \
}
DEFINE_ATTR_SETTER(String, "string", const char*, s)
DEFINE_ATTR_SETTER(Int, "int", int64_t, i)
DEFINE_ATTR_SETTER(Float, "float", float, f)
DEFINE_ATTR_SETTER(Bool, "bool", bool, b)
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
int version, const char* explanation) {
OpDeprecation* dep = builder->op_def.mutable_deprecation();
dep->set_version(version);
dep->set_explanation(explanation);
}
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
::tensorflow::OpRegistry::Global()->Register(
[builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
op_reg_data->op_def.Clear();
op_reg_data->op_def.MergeFrom(builder->op_def);
op_reg_data->shape_inference_fn = builder->shape_inference_func;
return Status::OK();
});
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
// is called and that the builder can be deleted.
Set_TF_Status_from_Status(
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
delete builder;
}
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
TF_OpDefinitionBuilder* builder,
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
TF_Status* status)) {
builder->shape_inference_func =
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
TF_Status* c_status = TF_NewStatus();
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
shape_inference_func(c_ctx, c_status);
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
TF_DeleteStatus(c_status);
return result;
};
}
TF_ShapeHandle* TF_NewShapeHandle() {
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size) {
auto* handle = new ShapeHandle;
*handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
return reinterpret_cast<TF_ShapeHandle*>(handle);
}
void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* first,
TF_ShapeHandle* second,
TF_ShapeHandle* result,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
*reinterpret_cast<ShapeHandle*>(second),
reinterpret_cast<ShapeHandle*>(result));
Set_TF_Status_from_Status(status, s);
}
TF_DimensionHandle* TF_NewDimensionHandle() {
return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
}
int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
return cc_ctx->num_inputs();
}
void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
TF_ShapeHandle* handle,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
if (0 < i || i >= cc_ctx->num_inputs()) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
}
if (TF_GetCode(status) == TF_OK) {
auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
*cc_result = cc_ctx->input(i);
}
}
int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* handle) {
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
}
void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
TF_ShapeHandle* handle,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
if (0 < i || i >= cc_ctx->num_outputs()) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
}
if (TF_GetCode(status) == TF_OK) {
cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
}
}
void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
if (handle == nullptr) {
return;
}
delete reinterpret_cast<ShapeHandle*>(handle);
}
void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
if (handle == nullptr) {
return;
}
delete reinterpret_cast<DimensionHandle*>(handle);
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_ShapeInferenceContext_GetAttr##func( \
TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
cc_type v; \
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
Status s = cc_ctx->GetAttr(attr_name, &v); \
Set_TF_Status_from_Status(status, s); \
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
#define DEFINE_RANK_FUNC(func_name) \
void TF_ShapeInferenceContext##func_name( \
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
TF_ShapeHandle* result, TF_Status* status) { \
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \
auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \
Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \
Set_TF_Status_from_Status(status, s); \
}
DEFINE_RANK_FUNC(WithRank)
DEFINE_RANK_FUNC(WithRankAtLeast)
DEFINE_RANK_FUNC(WithRankAtMost)
int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* handle) {
return reinterpret_cast<InferenceContext*>(ctx)->Rank(
*reinterpret_cast<ShapeHandle*>(handle));
}
void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape_handle, int64_t i,
TF_DimensionHandle* result) {
int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
if (i < -rank || i >= rank) {
*cc_result = DimensionHandle();
return;
}
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
*cc_result = cc_ctx->Dim(*cc_shape_handle, i);
}
int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
return InferenceContext::ValueKnown(
*reinterpret_cast<DimensionHandle*>(dim_handle));
}
void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
Status s = ::tensorflow::shape_inference::UnknownShape(
reinterpret_cast<InferenceContext*>(ctx));
Set_TF_Status_from_Status(status, s);
}
void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
TF_ShapeHandle* shape_handle,
int64_t start, int64_t end,
TF_ShapeHandle* result,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
start, end, cc_result);
Set_TF_Status_from_Status(status, s);
}
int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
return InferenceContext::Value(
*reinterpret_cast<DimensionHandle*>(dim_handle));
}

407
tensorflow/c/ops.h Normal file
View File

@ -0,0 +1,407 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Routines for registering new ops and for implementing op shape inference
// functions.
//
// This API is alpha software and is subject to change.
//
// REGISTRATION
// ------------
//
// In order to register a new op, create a new TF_OpDefinitionBuilder:
//
// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName");
//
// Inputs, outputs and attributes can be added to the builder with the
// corresponding functions, e.g.
//
// TF_OpDefinitionBuilderAddInput(builder, "input1: int32");
// TF_OpDefinitionBuilderAddOutput(builder, "output1: int64");
// TF_OpDefinitionBuilderAddAttr(builder, "attr: int32");
//
// The builder may then be registered with TensorFlow using the
// TF_RegisterOpDefinition function. E.g.
//
// TF_Status* status = TF_NewStatus();
// TF_RegisterOpDefinition(builder, &status);
// if (TF_GetCode(status) != TF_OK) {
// // handle error
// }
//
// SHAPE INFERENCE
// ---------------
//
// You can provide a shape inference function that TensorFlow will call when it
// wants to understand the shape of outputs that the op will produce. Use the
// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape
// inference function pointer with TensorFlow. The following is an example of a
// very simple shape inference function:
//
// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
// TF_ShapeHandle* input = TF_NewShapeHandle();
// TF_ShapeInferenceContextGetInput(ctx, 0, input, status);
// if (TF_GetCode(status) == TF_OK) {
// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status);
// }
// TF_DeleteShapeHandle(input);
// }
//
// The following code registers the inference function with TensorFlow:
//
// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
//
// For more details about shape inference, see the documentation for
// TF_OpDefinitionBuilderSetShapeInferenceFunction.
#ifndef TENSORFLOW_C_OPS_H_
#define TENSORFLOW_C_OPS_H_
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include "tensorflow/c/c_api.h"
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
struct TF_DimensionHandle;
struct TF_OpDefinitionBuilder;
struct TF_ShapeHandle;
struct TF_ShapeInferenceContext;
// Returns a newly allocated op definition builder for the given op name. The
// returned builder may be customized with the `TF_OpDefinitionBuilder...`
// functions and then registered with TensorFlow with TF_RegisterOpDefinition.
//
// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or
// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never
// registered.
TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(
const char* op_name);
// Registers the given op builder with TensorFlow. Indicates success or
// otherwise in the given status.
//
// `builder` is freed whether the op was successfully registered or not. You
// must call either this function or TF_DeleteOpDefinitionBuilder to free the
// builder, but never both.
TF_CAPI_EXPORT extern void TF_RegisterOpDefinition(
TF_OpDefinitionBuilder* builder, TF_Status* status);
// Frees the given op definition builder. You must call either this function or
// TF_RegisterOpDefinition to free the builder, but never both.
TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
TF_OpDefinitionBuilder* builder);
//----------------------------------------------------
// Attribute functions.
// Adds a string attribute with the given name to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a string attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, const char* value);
// Adds a string list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a string list attribute with the given default values to the builder.
// `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, const char* values[],
size_t n);
// Adds an integer attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, int64_t value);
// Adds an integer list attribute with the given name and no default value to
// the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, int64_t values[],
size_t n);
// Adds a float attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, float value);
// Adds a float list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, float values[],
size_t n);
// Adds a boolean attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, bool value);
// Adds a boolean list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n);
// Adds the input with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type);
// Adds the output with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type);
// Sets the commutative property for the op built by the given builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
TF_OpDefinitionBuilder* builder, bool is_commutative);
// Sets the is_aggregate property of the builder to the given value.
//
// If is_aggregate is true, then the operation produced by this builder accepts
// N >= 2 inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same shape as the
// input. The optimizer may replace an aggregate op taking input from multiple
// devices with a tree of aggregate ops that aggregate locally within each
// device (and possibly within groups of nearby devices) before communicating.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate(
TF_OpDefinitionBuilder* builder, bool is_aggregate);
// Sets the is_stateful property of the builder to the given value.
//
// The op built by this builder is stateful if its behavior depends on some
// state beyond its input tensors (e.g. variable reading op) or if it has a
// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have no
// side-effects.
//
// By default Ops may be moved between devices. Stateful ops should either not
// be moved, or should only be moved if that state can also be moved (e.g. via
// some sort of save / restore). Stateful ops are guaranteed to never be
// optimized away by Common Subexpression Elimination (CSE).
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful(
TF_OpDefinitionBuilder* builder, bool is_stateful);
// Sets the allows_uninitialized_input property of the operation built by this
// builder.
//
// By default, all inputs to an Op must be initialized Tensors. Ops that may
// initialize tensors for the first time should set this field to true, to allow
// the Op to take an uninitialized Tensor as input.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput(
TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input);
// Adds a deprecation warning for the given op. This indicates to the user that
// `version` is the first TensorFlow GraphDef version for which the operation is
// deprecated. `explanation` should contain the reason for the deprecation and
// what to use instead.
//
// This function is only an indicator that the operation may disappear in a
// version of TensorFlow after `version`. It does not affect op registration.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated(
TF_OpDefinitionBuilder* builder, int version, const char* explanation);
// Sets the shape inference function for the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction(
TF_OpDefinitionBuilder* builder,
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
TF_Status* status));
//----------------------------------------------------
// Functions for TF_ShapeInferenceContext.
//
// Functions for implementing shape inference functions. TensorFlow uses these
// functions to determine the shape of tensors produced by an operation without
// having to actually run the operation. If an operation chooses to provide a
// shape inference function, it will be invoked by TensorFlow as needed.
//
// When invoked by TensorFlow, the shape inference function is provided with a
// TF_ShapeInferenceContext pointer. The function's implementation will use the
// accessor and mutator functions with names beginning with
// TF_ShapeInferenceContext to examine the input state and determine the output
// shape.
// Returns the number of inputs in the given shape inference context.
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs(
TF_ShapeInferenceContext* ctx);
// Returns a newly allocated shape handle. The shapes represented by these
// handles may be queried or mutated with the corresponding
// TF_ShapeInferenceContext... functions.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle();
// Places the ith input of the given shape inference context into the given
// shape handle, or returns a status other than TF_OK indicating why the input
// could not be retrieved
// (for example, if i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)).
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput(
TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle,
TF_Status* status);
// Places the given shape handle into the `i`th output position of the given
// context. Internally, the shape handle is copied; the caller may subsequently
// delete `handle`.
TF_CAPI_EXPORT
extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
int i, TF_ShapeHandle* handle,
TF_Status* status);
// Returns a newly-allocate shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size);
// Returns a newly allocated dimension handle. It must be freed with
// TF_DeleteDimensionHandle.
TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle();
// Interprets the named shape inference context attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// TF_DataType, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType(
TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val,
TF_Status* status);
// Returns the rank of the shape represented by the given handle.
TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
// Returns 1 if `handle` has a known rank, 0 otherwise.
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
// If <handle> has rank <rank>, or its rank is unknown, return OK and return the
// shape with asserted rank in <*result>. Otherwise an error is placed into
// `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// If <handle> has rank at least <rank>, or its rank is unknown, return OK and
// return the shape with asserted rank in <*result>. Otherwise an error is
// placed into `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// If <handle> has rank at most <rank>, or its rank is unknown, return OK and
// return the shape with asserted rank in <*result>. Otherwise an error is
// placed into `status`.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
TF_ShapeHandle* result, TF_Status* status);
// Places a handle to the ith dimension of the given shape into *result.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
TF_DimensionHandle* result);
// Returns 1 if the given handle represents a known dimension.
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
// [start:end]. <start> and <end> can be negative, to index from the end of the
// shape. <start> and <end> are set to the rank of <shape_handle> if > rank of
// <shape_handle>.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start,
int64_t end, TF_ShapeHandle* result, TF_Status* status);
// Places an unknown shape in all outputs for the given inference context. Used
// for shape inference functions with ops whose output shapes are unknown.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape(
TF_ShapeInferenceContext* ctx, TF_Status* status);
// Returns whether the given handle represents a known dimension.
TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown(
TF_DimensionHandle* dim_handle);
// Returns the value of the given dimension.
TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue(
TF_DimensionHandle* dim_handle);
// Returns in <*result> the result of appending the dimensions of <second> to
// those of <first>.
TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first,
TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status);
// Frees the given shape handle.
TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle);
// Frees the given dimension handle.
TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_OPS_H_

159
tensorflow/c/ops_test.cc Normal file
View File

@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(OpsTest, TestBasicOpRegistration) {
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
TF_OpDefinitionBuilderAddStringAttr(builder, "attr1");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Buffer* op_list_buffer = TF_GetAllOpList();
::tensorflow::OpList op_list;
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "SomeOp") {
ASSERT_EQ(2, op.input_arg_size());
ASSERT_EQ("input1", op.input_arg(0).name());
ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
ASSERT_EQ(1, op.attr_size());
ASSERT_EQ("string", op.attr(0).type());
found = true;
}
}
EXPECT_TRUE(found);
TF_DeleteStatus(status);
TF_DeleteBuffer(op_list_buffer);
}
void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
TF_ShapeHandle* handle = TF_NewShapeHandle();
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
TF_DeleteShapeHandle(handle);
}
TEST(OpsTest, TestShapeInference_IdentityFunction) {
ShapeInferenceTestOp op("SomeTestOp");
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_ASSERT_OK(
shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
TF_DeleteStatus(status);
}
// Creates an output whose shape is a vector of length
// TF_ShapeInferenceContextRank.
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
TF_ShapeHandle* handle = TF_NewShapeHandle();
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
ctx, TF_ShapeInferenceContextRank(ctx, handle));
TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
TF_DeleteShapeHandle(handle);
TF_DeleteShapeHandle(new_shape);
}
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
ShapeInferenceTestOp op("VectorizeTestOp");
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("VectorizeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
op, "[4,5,9]", "[3]"));
TF_DeleteStatus(status);
}
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
float values[] = {1, 2, 3, 4};
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
builder, "foo1", values, sizeof(values));
TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2",
"my string");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
std::string deprecation_msg = "use something else instead";
TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* op_list_buffer = TF_GetAllOpList();
::tensorflow::OpList op_list;
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
bool found = false;
for (const auto& op : op_list.op()) {
if (op.name() == "AttributeAccesorsOp") {
ASSERT_TRUE(op.is_commutative());
ASSERT_TRUE(op.is_aggregate());
ASSERT_TRUE(op.allows_uninitialized_input());
ASSERT_EQ(4, op.deprecation().version());
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
ASSERT_EQ(2, op.attr_size());
ASSERT_EQ("list(float)", op.attr(0).type());
AttrValue::ListValue l = op.attr(0).default_value().list();
ASSERT_EQ(1, l.f(0));
ASSERT_EQ(2, l.f(1));
ASSERT_EQ(3, l.f(2));
ASSERT_EQ(4, l.f(3));
ASSERT_EQ("string", op.attr(1).type());
ASSERT_EQ("my string", op.attr(1).default_value().s());
found = true;
}
}
ASSERT_TRUE(found);
TF_DeleteStatus(status);
TF_DeleteBuffer(op_list_buffer);
}
} // namespace
} // namespace tensorflow

View File

@ -89,7 +89,7 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
@ -163,7 +163,7 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
tensorflow::shape_inference::ShapeHandle shape;
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (!status->status.ok()) return;
if (TF_GetCode(status) != TF_OK) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
@ -174,7 +174,7 @@ void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
mutex_lock l(graph->mu);
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
new_src.index, &dst->node);
if (status->status.ok()) {
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.

View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
#define TENSORFLOW_C_TF_ATTRTYPE_H_
#ifdef __cplusplus
extern "C" {
#endif
// TF_AttrType describes the type of the value of an attribute on an operation.
typedef enum TF_AttrType {
TF_ATTR_STRING = 0,
TF_ATTR_INT = 1,
TF_ATTR_FLOAT = 2,
TF_ATTR_BOOL = 3,
TF_ATTR_TYPE = 4,
TF_ATTR_SHAPE = 5,
TF_ATTR_TENSOR = 6,
TF_ATTR_PLACEHOLDER = 7,
TF_ATTR_FUNC = 8,
} TF_AttrType;
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_ATTRTYPE_H_

View File

@ -8,6 +8,19 @@ package(
licenses(["notice"]) # Apache 2.0
filegroup(
name = "srcs",
srcs = [
"framework/gradients.h",
"framework/ops.h",
"framework/scope.h",
"framework/scope_internal.h",
"ops/array_ops.h",
"ops/while_loop.h",
"//tensorflow/cc/saved_model:loader.h",
],
)
load(
"//tensorflow:tensorflow.bzl",
"cc_library_with_android_deps",
@ -190,6 +203,7 @@ tf_cc_test(
deps = [
":ops",
":scope",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -606,16 +620,13 @@ tf_gen_op_wrappers_cc(
visibility = ["//tensorflow:internal"],
)
cc_library_with_android_deps(
cc_library(
name = "cc_op_gen_main",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_main.cc",
],
android_deps = [
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core/api_def:base_api_def",

View File

@ -42,14 +42,19 @@ namespace {
const int kRightMargin = 79;
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// bazel-out/.../(bin|genfiles)/(external/YYY/)?XX
// to: XX.
string GetPath(const string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
auto pos = dot_h_fname.find("/bin/");
string result = dot_h_fname;
if (pos != string::npos) {
// - 1 account for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
result = dot_h_fname.substr(pos + sizeof("/bin/") - 1);
} else {
pos = dot_h_fname.find("/genfiles/");
if (pos != string::npos) {
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
}
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {

View File

@ -531,4 +531,23 @@ Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
return InternalScope::NewScope(graph, status, refiner);
}
Status CreateOutputWithScope(string op_name,
absl::Span<const ::tensorflow::Input> inputs,
const Scope& scope, Output* output) {
TF_RETURN_IF_ERROR(scope.status());
const auto unique_name = scope.GetUniqueNameForOp(op_name);
auto builder = ::tensorflow::NodeBuilder(unique_name, op_name);
for (auto input : inputs) {
TF_RETURN_IF_ERROR(scope.status());
builder = builder.Input(input.node());
}
::tensorflow::Node* ret;
scope.UpdateBuilder(&builder);
TF_RETURN_IF_ERROR(scope.status());
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
TF_RETURN_IF_ERROR(scope.status());
*output = Output(ret, 0);
return Status::OK();
}
} // namespace tensorflow

View File

@ -255,6 +255,12 @@ struct CompositeOpScopes {
Scope last;
};
// Creates a node of the given operation, with the given inputs, and assigns the
// result to output. This does not support the ability to add additional
// attributes.
Status CreateOutputWithScope(string op_name,
absl::Span<const ::tensorflow::Input> inputs,
const Scope& scope, Output* output);
/// @}
} // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -145,4 +147,14 @@ TEST(ScopeTest, ControlDeps) {
EXPECT_EQ(c_c.control_deps().size(), 3);
}
TEST(ScopeTest, CreateOutput) {
Scope root = Scope::NewRootScope();
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output add;
ASSERT_TRUE(
CreateOutputWithScope("Add", {a, a}, root.WithOpName("add"), &add).ok());
EXPECT_EQ(add.node()->name(), "add");
EXPECT_EQ(add.node()->type_string(), "Add");
}
} // namespace tensorflow

View File

@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op,
string kernel_type;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type));
bool antialias;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias));
grad_outputs->push_back(internal::ScaleAndTranslateGrad(
scope, grad_inputs[0], op.input(0), op.input(2), op.input(3),
internal::ScaleAndTranslateGrad::KernelType(kernel_type)));
internal::ScaleAndTranslateGrad::KernelType(kernel_type)
.Antialias(antialias)));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper);
Status CropAndResizeGradHelper(const Scope& scope, const Operation& op,

View File

@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test {
}
template <typename T>
void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x,
Output* y) {
void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale,
Input translation, const string& kernel_type, bool antialias,
Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f});
*y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation,
ScaleAndTranslate::KernelType(kernel_type)
.Antialias(antialias)
.Antialias(antialias));
TF_ASSERT_OK(scope_.status());
}
template <typename X_T, typename Y_T, typename JAC_T>
void TestResize() {
TensorShape x_shape({1, 2, 3, 1});
void TestScaleAndTranslate(const TensorShape x_shape, const int out_height,
const int out_width, Input scale,
Input translation, const string& kernel_type,
bool antialias) {
Tensor x_data = MakeData<X_T>(x_shape);
Output x, y;
MakeOp<X_T>(x_data, {4, 6}, &x, &y);
MakeOp<X_T>(x_data, {out_height, out_width}, scale, translation,
kernel_type, antialias, &x, &y);
JAC_T max_error;
TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
EXPECT_LT(max_error, 1e-3);
scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error)));
EXPECT_LT(max_error, 2e-3);
}
const std::vector<Input> kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f},
Input{2.1f, 2.1f}};
const std::vector<Input> kTranslations = {
Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f},
Input{100.0f, 200.0f}};
Scope scope_;
};
TEST_F(ScaleAndTranslateGradTest, Works) { TestResize<float, float, float>(); }
TEST_F(ScaleAndTranslateGradTest, TestGrads) {
const std::vector<std::string> kKernelTypes = {"lanczos1", "lanczos3",
"lanczos5", "gaussian"};
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) {
constexpr int kOutHeight = 4;
constexpr int kOutWidth = 6;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
TestScaleAndTranslate<float, float, float>(kXShape, kOutHeight, kOutWidth,
scale, translation, "lanczos3",
false);
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 2, 3, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) {
const std::vector<std::string> kKernelTypes = {"lanczos3", "gaussian"};
constexpr int kOutHeight = 2;
constexpr int kOutWidth = 3;
const TensorShape kXShape = TensorShape({1, 4, 6, 1});
for (const Input scale : kScales) {
for (const Input translation : kTranslations) {
for (const std::string& kernel_type : kKernelTypes) {
TestScaleAndTranslate<float, float, float>(
kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type,
true);
}
}
}
}
class CropAndResizeGradTest : public ::testing::Test {
protected:
@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test {
template <typename T>
void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind,
const Input& crop_szie, Output* x, Output* y) {
const Input& crop_size, Output* x, Output* y) {
*x = Const<T>(scope_, x_data);
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie,
*y = CropAndResize(scope_, *x, boxes, box_ind, crop_size,
CropAndResize::Method("bilinear"));
TF_ASSERT_OK(scope_.status());
}

View File

@ -17,6 +17,11 @@ load(
"if_not_mobile",
"tf_cc_test",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
"if_static_and_not_mobile",
)
cc_library(
name = "constants",
@ -78,12 +83,13 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":loader_lite",
] + if_not_mobile([
] + if_static_and_not_mobile([
"//tensorflow/core:tensorflow",
]) + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
]),
@ -91,6 +97,19 @@ cc_library(
cc_library(
name = "loader_lite",
hdrs = ["loader.h"],
deps = if_static([
":loader_lite_impl",
]) + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
# mobile not supported yet
]),
)
cc_library(
name = "loader_lite_impl",
srcs = ["loader.cc"],
hdrs = ["loader.h"],
deps = [
@ -121,6 +140,7 @@ tf_cc_test(
":tag_constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",

View File

@ -148,7 +148,8 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
const std::vector<AssetFileDef>& asset_file_defs,
Session* session, const string& init_op_name) {
if (!init_op_name.empty()) {
LOG(INFO) << "Running initialization op on SavedModel bundle.";
LOG(INFO) << "Running initialization op on SavedModel bundle at path: "
<< export_dir;
std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;

View File

@ -36,6 +36,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",

View File

@ -18,27 +18,41 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
import logging as _logging
import os as _os
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorboard.summary._tf.summary'),
error_msg=(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation"))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v2.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v2.keras'))
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#

View File

@ -19,19 +19,30 @@ from __future__ import division as _division
from __future__ import print_function as _print_function
import os as _os
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=(
'tensorflow_estimator.python.estimator.api._v1.estimator'))
_component_api_helper.package_hook(
parent_package_str=__name__,
child_package_str=('tensorflow.python.keras.api._v1.keras'))
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
except ImportError:
pass
try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
except ImportError:
pass
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable

View File

@ -33,13 +33,13 @@ cc_library(
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@ -70,6 +70,7 @@ tf_cc_test(
],
deps = [
":tfcompile_lib",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -38,7 +38,7 @@ namespace tfcompile {
namespace {
using BufferInfo = cpu_function_runtime::BufferInfo;
using BufferInfo = xla::cpu_function_runtime::BufferInfo;
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
@ -213,7 +213,11 @@ Status GenResultMethods(const tf2xla::Config& config,
return errors::Internal("codegen requires the XLA result to be a tuple");
}
size_t num_results = ps.result().tuple_shapes_size();
if (config.fetch_size() + config.variable_size() != num_results) {
int readonly_variables = absl::c_count_if(
config.variable(),
[](const tf2xla::Variable& var) { return var.readonly(); });
if (config.fetch_size() + config.variable_size() - readonly_variables !=
num_results) {
return errors::InvalidArgument("mismatch between fetch_size(",
config.fetch_size(), ")+variable_size(",
config.variable_size(), ") and tuple_size(",
@ -256,36 +260,26 @@ Status GenVariableMethods(const tf2xla::Config& config,
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
void set_var_{{NAME}}_data({{TYPE}}* data) {
void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
set_arg_data({{I}}, data);
}
)";
const tf2xla::Variable& var = config.variable(i - config.feed_size());
*methods += RewriteWithName(
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
{{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
}
size_t num_results = ps.result().tuple_shapes_size();
for (int i = config.fetch_size(); i < num_results; ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* var_{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
}
{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
{{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
arg_data({{I}}))){{INDICES}};
}
const {{TYPE}}* var_{{NAME}}_data() const {
return static_cast<const {{TYPE}}*>(result_data({{I}}));
return static_cast<const {{TYPE}}*>(arg_data({{I}}));
}
const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
result_data({{I}}))){{INDICES}};
arg_data({{I}}))){{INDICES}};
}
)";
const tf2xla::Variable& var = config.variable(i - config.fetch_size());
const tf2xla::Variable& var = config.variable(i - config.feed_size());
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
*methods += RewriteWithName(
var.name().empty() ? var.node_name() : var.name(), code, rewrites);
}
@ -363,7 +357,7 @@ std::vector<string> BufferInfosToCppExpression(
? "~0ULL"
: absl::StrCat(encoded.second, "ULL");
return absl::StrCat(
"::tensorflow::cpu_function_runtime::BufferInfo({",
"::xla::cpu_function_runtime::BufferInfo({",
encoded.first, "ULL, ", encoded_second_as_str, "})");
});
return buffer_infos_as_strings;
@ -398,13 +392,15 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
/*allocate_entry_params=*/true);
const size_t arg_bytes_aligned =
xla::cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_args.data(), buffer_infos_for_args.size(),
/*allocate_entry_params=*/true);
const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
/*allocate_entry_params=*/true);
const size_t temp_bytes_aligned =
xla::cpu_function_runtime::AlignedBufferBytes(
buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
/*allocate_entry_params=*/true);
const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
// Create rewrite strings for namespace start and end.
@ -538,7 +534,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}
{{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
{{CLASS}}(AllocMode alloc_mode =
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
{{CLASS}}(const {{CLASS}}&) = delete;
@ -579,27 +576,37 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
// buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}
// Methods for managing variable buffers. Buffers are in row-major order. The
// input and output buffers may or may not be identical.
// Methods for managing variable buffers. Buffers are in row-major order.
//
// For read-write variables we generate the following methods:
//
// void set_var_X_data(T* data)
// Sets the buffer for variable X.
// Sets the buffer for variable X. Must be called before Run if the
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
//
// T* var_X_data()
// Returns the buffer of type T for variable X.
// Returns the buffer of type T for variable X. If the allocation mode is
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
// buffer passed to set_var_X_data.
//
// T& var_X(...dim indices...)
// Returns a reference to the value of type T for variable X,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// For readonly variables we generate the same set of methods, except that we
// use `const T` instead of `T`. We use `const T` to avoid erasing the
// constness of the buffer passed to `set_var_X_data` but the underlying
// buffer is not const (and thus the const can be safely const-cast'ed away)
// unless `set_var_X_data` is called with a pointer to constant storage.
{{METHODS_VARIABLE}}
private:
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
};

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
@ -34,7 +35,7 @@ namespace tensorflow {
namespace tfcompile {
namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
using ::xla::cpu_function_runtime::BufferInfo;
void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
@ -175,14 +176,19 @@ TEST(CodegenTest, Golden) {
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
tf2xla::Variable* variable = config.add_variable();
variable->set_node_name("myvar");
variable->set_node_name("myvar_readonly");
variable->mutable_shape()->add_dim()->set_size(1);
variable->set_type(DT_FLOAT);
variable->set_readonly(true);
tf2xla::Variable* variable2 = config.add_variable();
variable2->set_node_name("my/var");
variable2->set_name("myvar2");
variable2->mutable_shape()->add_dim()->set_size(5);
variable2->set_type(DT_INT32);
variable2->set_node_name("myvar");
variable2->mutable_shape()->add_dim()->set_size(1);
variable2->set_type(DT_FLOAT);
tf2xla::Variable* variable3 = config.add_variable();
variable3->set_node_name("my/var");
variable3->set_name("myvar2");
variable3->mutable_shape()->add_dim()->set_size(5);
variable3->set_type(DT_INT32);
CompileResult compile_result;
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
{},
@ -198,6 +204,7 @@ TEST(CodegenTest, Golden) {
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
xla::ShapeUtil::MakeShape(xla::F32, {1}),
xla::ShapeUtil::MakeShape(xla::F32, {1}),
xla::ShapeUtil::MakeShape(xla::S32, {5}),
},
xla::ShapeUtil::MakeTupleShape({

View File

@ -52,7 +52,7 @@ namespace bar {
// is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5])
//
// Memory stats:
// arg bytes total: 104
@ -91,7 +91,8 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}
MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
MyClass(AllocMode alloc_mode =
AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
MyClass(const MyClass&) = delete;
@ -214,71 +215,97 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
result_data(0)))[dim0][dim1];
}
// Methods for managing variable buffers. Buffers are in row-major order. The
// input and output buffers may or may not be identical.
// Methods for managing variable buffers. Buffers are in row-major order.
//
// For read-write variables we generate the following methods:
//
// void set_var_X_data(T* data)
// Sets the buffer for variable X.
// Sets the buffer for variable X. Must be called before Run if the
// allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
//
// T* var_X_data()
// Returns the buffer of type T for variable X.
// Returns the buffer of type T for variable X. If the allocation mode is
// RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
// buffer passed to set_var_X_data.
//
// T& var_X(...dim indices...)
// Returns a reference to the value of type T for variable X,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
// For readonly variables we generate the same set of methods, except that we
// use `const T` instead of `T`. We use `const T` to avoid erasing the
// constness of the buffer passed to `set_var_X_data` but the underlying
// buffer is not const (and thus the const can be safely const-cast'ed away)
// unless `set_var_X_data` is called with a pointer to constant storage.
void set_var_myvar_data(float* data) {
void set_var_myvar_readonly_data(const float* data) {
set_arg_data(2, data);
}
void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(3, data);
const float* var_myvar_readonly_data() {
return static_cast<const float*>(arg_data(2));
}
const float& var_myvar_readonly() {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
const float* var_myvar_readonly_data() const {
return static_cast<const float*>(arg_data(2));
}
const float& var_myvar_readonly() const {
return (*static_cast<const float(*)[1]>(
arg_data(2)))[0];
}
void set_var_myvar_data(float* data) {
set_arg_data(3, data);
}
float* var_myvar_data() {
return static_cast<float*>(result_data(1));
return static_cast<float*>(arg_data(3));
}
float& var_myvar() {
return (*static_cast<float(*)[1]>(
result_data(1)))[0];
arg_data(3)))[0];
}
const float* var_myvar_data() const {
return static_cast<const float*>(result_data(1));
return static_cast<const float*>(arg_data(3));
}
const float& var_myvar() const {
return (*static_cast<const float(*)[1]>(
result_data(1)))[0];
arg_data(3)))[0];
}
void set_var_myvar2_data(tensorflow::int32* data) {
set_arg_data(4, data);
}
tensorflow::int32* var_myvar2_data() {
return static_cast<tensorflow::int32*>(result_data(2));
return static_cast<tensorflow::int32*>(arg_data(4));
}
tensorflow::int32& var_myvar2(size_t dim0) {
return (*static_cast<tensorflow::int32(*)[5]>(
result_data(2)))[dim0];
arg_data(4)))[dim0];
}
const tensorflow::int32* var_myvar2_data() const {
return static_cast<const tensorflow::int32*>(result_data(2));
return static_cast<const tensorflow::int32*>(arg_data(4));
}
const tensorflow::int32& var_myvar2(size_t dim0) const {
return (*static_cast<const tensorflow::int32(*)[5]>(
result_data(2)))[dim0];
arg_data(4)))[dim0];
}
private:
// Number of buffers for the compiled computation.
static constexpr size_t kNumBuffers = 6;
static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::tensorflow::cpu_function_runtime::BufferInfo
static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
static const ::xla::cpu_function_runtime::BufferInfo
kBufferInfos[kNumBuffers] = {
::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
};
return kBufferInfos;
}
@ -309,7 +336,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
static const xla::ProgramShapeProto* StaticProgramShape() {
static const xla::ProgramShapeProto* kShape = []() {
xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132);
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 149);
return proto;
}();
return kShape;

View File

@ -36,6 +36,7 @@ py_binary(
name = "make_test_graphs",
testonly = 1,
srcs = ["make_test_graphs.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",

View File

@ -159,10 +159,11 @@ def tfvariable(_):
def tfvariable_sequential_updates(_):
x = variables.Variable(1.0, name='x')
y = variables.Variable(1.0, name='y')
updates = control_flow_ops.no_op()
for _ in range(3):
with ops.control_dependencies([updates]):
x_val = x.read_value() + 1.0
x_val = x.read_value() + y
updates = x.assign_sub(0.1 * x_val)
array_ops.identity(updates, name='result')

View File

@ -7,3 +7,9 @@ variable {
node_name: "x"
type: DT_FLOAT
}
variable {
node_name: "y"
type: DT_FLOAT
readonly: true
}

View File

@ -83,7 +83,8 @@ TEST(TFCompileTest, Add) {
// Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers.
TEST(TFCompileTest, Add_SetArg) {
AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
AddComp add(
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
int32 arg_x = 10;
int32 arg_y = 32;
@ -296,7 +297,7 @@ TEST(TFCompileTest, MatMul2_SetArg) {
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul(
foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
matmul.set_thread_pool(&device);
// Test using the set_argN_data() methods.
@ -502,20 +503,50 @@ TEST(TFCompileTest, VariableSequentialUpdates) {
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
// This implements the recursion:
// x[0] = 1.0
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
// x[0] = 2.0
// x[n+1] = x[n] - 0.1*(x[n-1] + y)
VariableSequentialUpdatesComp fn;
float x = 1;
fn.set_var_x_data(&x);
fn.var_x() = 2;
*const_cast<float*>(fn.var_y_data()) = 1;
fn.set_thread_pool(&device);
// First calculate x[3]
fn.Run();
EXPECT_NEAR(x, 0.458f, 1e-6);
EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6);
const float y = 1;
fn.set_var_y_data(&y);
// Now const_cast<float*>(fn.var_y_data()) is not longer legal since we've set
// the buffer to point to a constant location.
// Then calculate x[6]
fn.Run();
EXPECT_NEAR(x, 0.062882f, 1e-6);
EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6);
}
TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
Eigen::ThreadPool tp(1);
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
// This implements the recursion:
// x[0] = 2.0
// x[n+1] = x[n] - 0.1*(x[n-1] + 1.0)
VariableSequentialUpdatesComp fn(
XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
float x = 2;
float y = 1;
fn.set_var_x_data(&x);
fn.set_var_y_data(&y);
fn.set_thread_pool(&device);
// First calculate x[3]
fn.Run();
EXPECT_NEAR(x, 1.187f, 1e-6);
// Then calculate x[6]
fn.Run();
EXPECT_NEAR(x, 0.594322f, 1e-6);
}
TEST(TFCompileTest, AssertEqAndReturnDiff) {

View File

@ -163,7 +163,10 @@ def tf_library(
header_file = name + ".h"
metadata_object_file = name + "_tfcompile_metadata.o"
function_object_file = name + "_tfcompile_function.o"
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
# The XLA backends morph kernal name prefix __ that is not in the form of
# __xla_.
ep = ("__xla_" + native.package_name() + "__" + name).replace("/", "_")
if type(tfcompile_flags) == type(""):
flags = tfcompile_flags
else:
@ -171,6 +174,20 @@ def tf_library(
"'" + arg.replace("'", "'\\''") + "'"
for arg in (tfcompile_flags or [])
])
# Do this before we append the `select` into `flags`, because doing so
# transforms `flags` into a variable of type `select`, and we can't call
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
# Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel
# build --cpu=haswell). We put it at the beginning of the flags list so
# that tfcompile_flags can override if if desired.
flags = select({
"//tools/target_cpu:haswell": "--target_cpu=haswell ",
"//conditions:default": "",
}) + flags
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
else:
@ -248,7 +265,6 @@ def tf_library(
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
native.cc_library(
name = name,
srcs = [function_object_file, metadata_object_file],

View File

@ -17,15 +17,14 @@ package_group(
package(
default_visibility = [
":internal",
# BEGIN-GOOGLE-INTERNAL
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
# END-GOOGLE-INTERNAL
],
)
# NB! Removing the cc_header_only_library import breaks the OSS build since
# copybara injects some build rules that use it.
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
@ -78,10 +77,10 @@ cc_library(
srcs = ["xla_cpu_device.cc"],
visibility = [":friends"],
deps = [
":create_xla_launch_op", # buildcleaner: keep
":flags",
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -98,9 +97,9 @@ cc_library(
srcs = ["xla_gpu_device.cc"],
visibility = [":friends"],
deps = [
":create_xla_launch_op", # buildcleaner: keep
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -168,7 +167,6 @@ cc_library(
":xla_tensor",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -199,6 +197,7 @@ cc_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:host_constant_op",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
@ -212,6 +211,8 @@ cc_library(
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor/platform",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
@ -222,6 +223,7 @@ cc_library(
name = "shape_inference_helpers",
srcs = ["shape_inference_helpers.cc"],
hdrs = ["shape_inference_helpers.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:graph"],
)
@ -236,6 +238,7 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
@ -254,6 +257,11 @@ cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
hdrs = ["xla_launch_util.h"],
# TODO(skyewm): remove this once XlaAllocator is factored out.
visibility = [
":internal",
"//tensorflow/compiler/xla/python:__pkg__",
],
deps = [
":common",
":xla_compilation_cache",
@ -263,7 +271,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -271,6 +278,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
@ -283,7 +291,6 @@ cc_library(
hdrs = ["xla_compilation_cache.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
@ -326,10 +333,10 @@ cc_library(
)
cc_library(
name = "create_xla_launch_op",
name = "xla_kernel_creator",
srcs = [
"create_xla_launch_op.cc",
"create_xla_launch_op.h",
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
deps = [
":common",
@ -346,13 +353,13 @@ cc_library(
)
tf_cc_test(
name = "create_xla_launch_op_test",
name = "xla_kernel_creator_test",
srcs = [
"create_xla_launch_op.h",
"create_xla_launch_op_test.cc",
"xla_kernel_creator.h",
"xla_kernel_creator_test.cc",
],
deps = [
":create_xla_launch_op",
":xla_kernel_creator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -370,6 +377,7 @@ cc_library(
srcs = ["resource_operation_safety_analysis.cc"],
hdrs = ["resource_operation_safety_analysis.h"],
deps = [
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
@ -417,7 +425,6 @@ cc_library(
hdrs = ["shape_inference.h"],
deps = [
":shape_inference_helpers",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@ -467,6 +474,9 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@ -498,6 +508,7 @@ cc_library(
"encapsulate_xla_computations_pass.cc",
"extract_outside_compilation_pass.cc",
"increase_dynamism_for_auto_jit_pass.cc",
"introduce_floating_point_jitter_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@ -510,24 +521,28 @@ cc_library(
"encapsulate_xla_computations_pass.h",
"extract_outside_compilation_pass.h",
"increase_dynamism_for_auto_jit_pass.h",
"introduce_floating_point_jitter_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
],
deps = [
"compilability_check_util",
":common",
":device_util",
":encapsulate_util",
":flags",
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@ -535,6 +550,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@ -561,19 +577,49 @@ cc_library(
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "device_util",
srcs = ["device_util.cc"],
hdrs = ["device_util.h"],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "device_util_test",
srcs = ["device_util_test.cc"],
deps = [
":device_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
@ -631,13 +677,15 @@ tf_cc_test(
srcs = [
"build_xla_ops_pass_test.cc",
"clone_constants_for_better_clustering_test.cc",
"compilation_passes_test_main.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"extract_outside_compilation_pass_test.cc",
"increase_dynamism_for_auto_jit_pass_test.cc",
"introduce_floating_point_jitter_pass_internal.h",
"introduce_floating_point_jitter_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
"rearrange_function_argument_pass_test.cc",
],
deps = [
":common",
@ -658,6 +706,7 @@ tf_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
@ -677,6 +726,7 @@ tf_cc_test(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -696,6 +746,7 @@ tf_cc_test(
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -708,43 +759,6 @@ tf_cc_test(
],
)
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
cc_library(
name = "node_matchers",
testonly = True,
@ -776,6 +790,34 @@ tf_cc_test(
],
)
cc_library(
name = "compilability_check_util",
srcs = ["compilability_check_util.cc"],
hdrs = ["compilability_check_util.h"],
deps = [
":common",
":device_util",
":flags",
":resource_operation_safety_analysis",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
tf_custom_op_py_library(
name = "xla_ops_py",
kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -23,12 +24,13 @@ limitations under the License.
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
@ -42,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
@ -74,7 +77,8 @@ Operation DataToControl(const Scope& scope, Output data) {
// Replaces each outgoing edge from `old_node` with a merge node that merges in
// the corresponding output from `new_node`.
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) {
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
bool insert_print_nodes) {
if (!s.status().ok()) {
return;
}
@ -91,7 +95,21 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) {
if (merged_output.node() == nullptr) {
ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
{Output(old_node, oidx), Output(new_node, oidx)});
merged_output = merged_outputs[oidx] = merge_op.output;
if (insert_print_nodes) {
string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
.WithDevice(cpu_device)
.WithAssignedDevice(cpu_device),
merge_op.output, {merge_op.output},
ops::Print::Attrs{}
.Message(absl::StrCat("output ", oidx, " from ",
old_node->name(), " is "))
.FirstN(1000)
.Summarize(-1));
merged_output = merged_outputs[oidx] = print_op;
} else {
merged_output = merged_outputs[oidx] = merge_op.output;
}
}
Node* dst = e->dst();
@ -215,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
}
// Returns true (into `result`) if a node placed on `device` must be compiled.
Status DeviceRequiresCompilation(const string& device, bool* result) {
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(device, &device_type));
const XlaOpRegistry::DeviceRegistration* registration = nullptr;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return errors::Internal("Could not find compilation device ",
device_type.type());
}
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
jit::DeviceId device, bool* result) {
const XlaOpRegistry::DeviceRegistration* registration =
device_info_cache.GetCompilationDevice(device);
*result = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK();
@ -275,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
return Status::OK();
}
Status InferDeviceForCluster(Node* n, const string& function_name,
const FunctionLibraryDefinition& flib_def,
string* result) {
xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
jit::DeviceInfoCache* device_info_cache, Node* n,
const string& function_name, const FunctionLibraryDefinition& flib_def) {
const FunctionDef* func_def = flib_def.Find(function_name);
TF_RET_CHECK(func_def) << "Could not find " << function_name;
std::set<string> device_names;
jit::DeviceSet device_set;
for (const NodeDef& ndef : func_def->node_def()) {
VLOG(3) << ndef.DebugString();
if (!ndef.device().empty()) {
device_names.insert(ndef.device());
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(ndef.device()));
device_set.Insert(device_id);
}
}
@ -293,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
// instead.
device_names.insert(n->assigned_device_name());
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(n->assigned_device_name()));
device_set.Insert(device_id);
}
std::vector<string> device_names_vector;
absl::c_copy(device_names, std::back_inserter(device_names_vector));
Status s = PickDeviceForXla(device_names_vector, true, result);
if (s.ok()) {
VLOG(2) << "For " << function_name << " PickDeviceForXla("
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
}
return s;
TF_ASSIGN_OR_RETURN(jit::DeviceId result,
PickDeviceForXla(*device_info_cache, device_set,
/*allow_mixing_unknown_and_cpu=*/true));
VLOG(2) << "For " << function_name << " PickDeviceForXla("
<< device_info_cache->DebugString(device_set) << ") -> "
<< device_info_cache->GetNameFor(result);
return result;
}
Status ReplaceNodeWithXlaCompileAndXlaRun(
jit::DeviceInfoCache* device_info_cache,
const GraphOptimizationPassOptions& options,
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
Graph* g, Node* n) {
bool insert_print_nodes, Graph* g, Node* n) {
XlaClusterInfo cluster_info;
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
string device;
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
flib_def, &device));
TF_ASSIGN_OR_RETURN(
jit::DeviceId device,
InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
flib_def));
bool requires_compilation;
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
&requires_compilation));
if (!lazy_compilation_enabled) {
requires_compilation = true;
}
string device_name_str = string(device_info_cache->GetNameFor(device));
Status status;
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
.NewSubScope(n->name())
.WithDevice(n->requested_device())
.WithAssignedDevice(device);
.WithAssignedDevice(device_name_str);
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
/*constants=*/cluster_info.constant_inputs,
@ -378,7 +401,8 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
/*new_node=*/xla_run.operation.node());
MergeOutgoingDataEdges(root, /*old_node=*/n,
/*new_node=*/xla_run.operation.node());
/*new_node=*/xla_run.operation.node(),
insert_print_nodes);
TF_RETURN_IF_ERROR(root.status());
@ -418,15 +442,20 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
bool lazy_compilation_enabled =
enable_lazy_compilation_
? *enable_lazy_compilation_
: GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
: GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
bool insert_print_nodes =
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
jit::DeviceInfoCache device_info_cache;
for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
options, *options.flib_def, lazy_compilation_enabled, graph, n));
&device_info_cache, options, *options.flib_def,
lazy_compilation_enabled, insert_print_nodes, graph, n));
}
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
}
return Status::OK();

View File

@ -122,7 +122,7 @@ Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs(
Status CloneConstantsForBetterClusteringPass::Run(
const GraphOptimizationPassOptions& options) {
if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) {
if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) {
return Status::OK();
}

View File

@ -0,0 +1,277 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
bool HasResourceInput(const Node& node) {
return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}
} // anonymous namespace
bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
// IsCompilableCall().
if (node.type_string() == "SymbolicGradient") return false;
if (node.type_string() == "Const") {
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
// registered Const KernelDef says that it does, to support no-op Assert for
// tfcompile.
const AttrValue* attr = node.attrs().Find("dtype");
if (attr != nullptr && attr->type() == DT_STRING) {
return false;
}
}
// XLA does not offer guaranteed aliasing between the input and output of the
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok();
}
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
const Node& while_node, int depth, FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
}
return true;
}
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
const NodeDef& call_def, int depth, FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
<< ": function depth limit exceeded.";
return false;
}
FunctionLibraryRuntime::Handle handle;
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
if (!status.ok()) {
VLOG(2) << "Rejecting " << call_def.DebugString()
<< ": could not instantiate: " << status;
return false;
}
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
for (Node* node : fbody->graph->op_nodes()) {
if (!IsCompilableNode(*node, depth + 1, lib_runtime)) {
return false;
}
}
return true;
}
bool LogNotCompilableAndReturn(const Node& node,
absl::string_view reason = "") {
VLOG(3) << "Not clustering " << node.name() << " (op " << node.type_string()
<< ")" << (reason.empty() ? "" : ": ") << reason;
return false;
}
bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) {
// b/127344411: SelfAdjointEigV2 and Svd precision issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd";
}
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) {
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
return node.type_string() == "SelfAdjointEigV2" ||
node.type_string() == "Svd" || node.type_string() == "Qr";
}
bool RecursiveCompilabilityChecker::IsCompilableNode(
const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) {
if (node.IsSource() || node.IsSink()) {
return LogNotCompilableAndReturn(node, "source or sink node");
}
// _Arg nodes in a top-level function represent feeds and _Retval nodes in a
// top-level function represent fetches.
if (depth == 0 &&
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
return LogNotCompilableAndReturn(node, "depth is 0");
}
if (node.attrs().Find("_scoped_allocator") ||
node.attrs().Find("_forward_from")) {
// TODO(b/128858118): XLA does not support _scoped_allocator and
// _forward_from.
return LogNotCompilableAndReturn(
node, "_scoped_allocator or _forward_from attribute");
}
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
if (!IsCompilableCall(node.def(), depth + 1, lib_runtime)) {
return LogNotCompilableAndReturn(node, "unsupported function");
}
} else if (!HasXLAKernel(node)) {
return LogNotCompilableAndReturn(node, "unsupported op");
}
if (node.type_string() == "While" &&
!IsCompilableWhile(node, depth + 1, lib_runtime)) {
return LogNotCompilableAndReturn(node, "unsupported while");
}
if (!op_filter_.allow_stateful_rng_ops &&
IsStatefulRandomOp(node.type_string())) {
return LogNotCompilableAndReturn(node, "stateful random op");
}
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
return LogNotCompilableAndReturn(node);
}
if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
IsAssertOrCheckNumerics(node.type_string())) {
return LogNotCompilableAndReturn(node, "Assert or CheckNumerics");
}
if (!op_filter_.allow_ops_producing_or_consuming_variant &&
OpProducesOrConsumesVariant(node)) {
return LogNotCompilableAndReturn(node, "DT_VARIANT producer/consumer");
}
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
return LogNotCompilableAndReturn(node, "Stack op");
}
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
return LogNotCompilableAndReturn(node, "TensorArray op");
}
if (!op_filter_.allow_resource_ops_in_called_functions && depth > 0 &&
HasResourceInput(node)) {
return LogNotCompilableAndReturn(node,
"resource variable op in called function");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) {
return LogNotCompilableAndReturn(node, "operation with correctness issues");
}
if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) {
return LogNotCompilableAndReturn(node, "slow operation");
}
return true;
}
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration) {
RecursiveCompilabilityChecker::OperationFilter op_filter;
op_filter.allow_resource_ops_in_called_functions =
registration.cluster_resource_variable_ops_unsafely;
op_filter.allow_stack_ops = registration.cluster_stack_ops;
op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
op_filter.allow_control_trigger = registration.cluster_control_trigger;
op_filter.allow_eliding_assert_and_checknumerics_ops =
registration.elide_assert_and_checknumerics;
op_filter.allow_ops_producing_or_consuming_variant =
registration.cluster_variant_ops;
op_filter.allow_slow_and_inaccurate_ops =
registration.cluster_slow_and_inaccurate_ops;
return op_filter;
}
} // namespace tensorflow

View File

@ -0,0 +1,175 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
// Checks whether a TF node can be compiled or not. "Recursive" as in for call
// and functional while nodes it recursively checks whether the callee functions
// can be compiled.
class RecursiveCompilabilityChecker {
public:
// Aggregates information about what kinds of ops are allowed.
struct OperationFilter { // TODO(lzr): Add AllowEverything() helper.
// Whether resource variable ops are allowed are allowed in callees. We do
// not allow resource variable ops in called functions (either as direct TF
// calls or as higher order control flow ops) because we do not yet model
// their memory effects in jit/resource_variable_safety_analysis.
bool allow_resource_ops_in_called_functions;
// Whether Stack operations are allowed. We avoid auto-clustering Stack
// operations in general because we do not support snapshotting them.
//
// TODO(b/112837194): This restriction can be lifted with some work.
bool allow_stack_ops;
// Whether TensorArray operations are allowed. We avoid auto-clustering
// TensorArray operations in general because we do not support snapshotting
// them.
//
// TODO(b/112837194): This restriction can be lifted with some work.
bool allow_tensor_array_ops;
// Whether stateful RNG ops are allowed. XLA's RNG does not have the same
// seeding behavior as TensorFlow's RNG (b/34749654). So we avoid
// auto-clustering stateful RNG ops.
bool allow_stateful_rng_ops;
// TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound
// to cluster ControlTrigger because of how we use deadness analysis.
bool allow_control_trigger;
// Whether it is okay to "cluster" Assert and CheckNumerics by simply
// removing them (they're not removed during clustering, but their
// XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
// that the user is not surprised when XLA is implicitly enabled. If the
// user explicitly specifies to use XLA, it is fine to resort to a dummy
// implementation. Currently Assert and CheckNumerics ops have dummy XLA
// implementations.
bool allow_eliding_assert_and_checknumerics_ops;
// Whether ops that produce or consume DT_VARIANT values are allowed. We
// don't auto-cluster these ops because we don't yet support live-in or
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;
// Whether ops known to be slow or to have correctness issues should be
// auto-clustered.
bool allow_slow_and_inaccurate_ops;
};
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
// Returns true if `node` can be compiled by XLA.
bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) {
return IsCompilableNode(node, /*depth=*/0, lib_runtime);
}
// Returns true if `call_def` can be compiled by XLA. It is assumed that
// `call_def` is a call operation.
bool IsCompilableCall(const NodeDef& call_def,
FunctionLibraryRuntime* lib_runtime) {
return IsCompilableCall(call_def, /*depth=*/0, lib_runtime);
}
// Returns true if XLA supports this Op, but we don't want to cluster it (ie:
// due to performance or correctness concerns).
bool OpIsInaccurate(const Node& node);
bool OpIsSlow(const Node& node);
private:
bool IsCompilableNode(const Node& node, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsCompilableCall(const NodeDef& call_def, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsCompilableWhile(const Node& while_node, int depth,
FunctionLibraryRuntime* lib_runtime);
bool IsStackOp(const Node& node) {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
}
bool IsTensorArrayOp(const Node& node) {
const XlaResourceOpInfo* op_info =
GetResourceOpInfoForOp(node.type_string());
return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
}
bool IsAssertOrCheckNumerics(absl::string_view op_name) {
return op_name == "Assert" || op_name == "CheckNumerics";
}
bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
op_name == "TruncatedNormal" || op_name == "Multinomial";
}
bool OpProducesOrConsumesVariant(const Node& node) {
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
return absl::c_any_of(node.input_types(), is_variant) ||
absl::c_any_of(node.output_types(), is_variant);
}
bool HasXLAKernel(const Node& node);
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
const OperationFilter& op_filter_;
const DeviceType& jit_device_type_;
};
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

View File

@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) {
void operator()(char* ptr) { free(ptr); }
};
std::unique_ptr<char, FreeDeleter> allocated_arg(
std::unique_ptr<char, FreeDeleter> enable_global_jit_arg(
strdup("--tf_xla_cpu_global_jit=true"));
args.push_back(enable_global_jit_arg.get());
args.push_back(allocated_arg.get());
std::unique_ptr<char, FreeDeleter> reduce_min_cluster_size_arg(
strdup("--tf_xla_min_cluster_size=2"));
args.push_back(reduce_min_cluster_size_arg.get());
int argc = args.size();

View File

@ -106,6 +106,8 @@ namespace tensorflow {
namespace {
using se::port::StatusOr;
// Represents a logical predicate, used as described in the algorithm overview
// above.
class Predicate {
@ -369,7 +371,8 @@ class PredicateFactory {
Predicate** predicate) {
TensorId tensor_id(node->name(), output_idx);
bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
bool is_boolean_tensor =
BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
TF_RET_CHECK(!must_be_true || is_boolean_tensor);
if (node->type_string() == "Const" && must_be_true) {
@ -698,7 +701,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status Populate();
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override;
void Print() const override;
absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
const;
@ -768,7 +772,8 @@ Status DeadnessAnalysisImpl::GetInputPreds(
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
if (it == predicate_map_.end()) {
GraphCycles graph_cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
TF_RETURN_IF_ERROR(
CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
// If we didn't return with an error above then the graph is probably
// fine and we have a bug in deadness analysis.
@ -1112,42 +1117,13 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
return Status::OK();
}
bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
CHECK(!node.IsMerge());
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
}
Predicate* pred = nullptr;
for (const Edge* edge : node.in_edges()) {
auto it = predicate_map_.find(InputEdgeToTensorId(edge));
CHECK(it != predicate_map_.end());
if (vlog_) {
VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
<< it->second->ToString();
}
// Today we just compare the predicates for equality (with some
// canonicalization/simplification happening before) but we could be more
// sophisticated here if need be. Comparing pointers is sufficient because
// we intern Predicate instances by their content.
if (pred != nullptr && pred != it->second) {
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> true";
}
return true;
}
pred = it->second;
}
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> false";
}
return false;
StatusOr<DeadnessAnalysis::DeadnessPredicate>
DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
auto it = predicate_map_.find(TensorId(n->name(), oidx));
TF_RET_CHECK(it != predicate_map_.end())
<< "could not find " << TensorId(n->name(), oidx).ToString()
<< " in predicate map";
return MakeDeadnessPredicate(it->second);
}
void DeadnessAnalysisImpl::Print() const {
@ -1212,4 +1188,8 @@ Status ComputePredicates(const Graph& graph,
}
} // namespace deadness_analysis_internal
string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
return static_cast<Predicate*>(predicate.pred_)->ToString();
}
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -43,24 +44,55 @@ namespace tensorflow {
// "liveness" already has other connotations.
class DeadnessAnalysis {
public:
// Returns true if `node` may have some live inputs and some dead inputs.
//
// This is a conservatively correct routine -- if it returns false then `node`
// is guaranteed to not have inputs with mismatching liveness, but not the
// converse.
//
// REQUIRES: node is not a Merge operation.
virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
// An opaque representation of a predicate. DeadnessPredicate
// instances that compare equal via operator== represent predicates
// that always evaluate to the same value.
struct DeadnessPredicate {
public:
DeadnessPredicate(const DeadnessPredicate&) = default;
DeadnessPredicate(DeadnessPredicate&&) = default;
DeadnessPredicate& operator=(const DeadnessPredicate&) = default;
DeadnessPredicate& operator=(DeadnessPredicate&&) = default;
bool operator==(const DeadnessPredicate& other) const {
return other.pred_ == pred_;
}
bool operator!=(const DeadnessPredicate& other) const {
return other.pred_ != pred_;
}
private:
explicit DeadnessPredicate(void* pred) : pred_(pred) {}
// This is really a Predicate*, but we don't want to expose that
// implementation detail to our clients. `pred_` has pointer equality so we
// can just compare the pointer in operator== and operator!=.
void* pred_;
friend class DeadnessAnalysis;
};
virtual se::port::StatusOr<DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const = 0;
// Prints out the internal state of this instance. For debugging purposes
// only.
virtual void Print() const = 0;
virtual ~DeadnessAnalysis();
string DebugString(DeadnessPredicate predicate) const;
// Run the deadness analysis over `graph` and returns an error or a populated
// instance of DeadnessAnalysis in `result`.
static Status Run(const Graph& graph,
std::unique_ptr<DeadnessAnalysis>* result);
protected:
static DeadnessPredicate MakeDeadnessPredicate(void* pred) {
return DeadnessPredicate(pred);
}
};
} // namespace tensorflow

View File

@ -37,6 +37,22 @@ limitations under the License.
namespace tensorflow {
namespace {
se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
const DeadnessAnalysis& deadness_analysis, const Node& n) {
absl::optional<DeadnessAnalysis::DeadnessPredicate> pred;
for (const Edge* edge : n.in_edges()) {
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate this_pred,
deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
if (pred && *pred != this_pred) {
return true;
}
pred = this_pred;
}
return false;
}
using deadness_analysis_internal::ComputePredicates;
using deadness_analysis_internal::PredicateMapTy;
@ -219,7 +235,10 @@ TEST(DeadnessAnalysisTest, BasicPositive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, BasicNegative) {
@ -232,7 +251,10 @@ TEST(DeadnessAnalysisTest, BasicNegative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndIsCommutative) {
@ -260,11 +282,27 @@ TEST(DeadnessAnalysisTest, AndIsCommutative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
bool has_inputs_with_mismatching_deadness;
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live1.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndIsAssociative) {
@ -287,7 +325,10 @@ TEST(DeadnessAnalysisTest, AndIsAssociative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrIsCommutative) {
@ -312,11 +353,27 @@ TEST(DeadnessAnalysisTest, OrIsCommutative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
bool has_inputs_with_mismatching_deadness;
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live1.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrIsAssociative) {
@ -336,7 +393,10 @@ TEST(DeadnessAnalysisTest, OrIsAssociative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndOfOr) {
@ -358,7 +418,10 @@ TEST(DeadnessAnalysisTest, AndOfOr) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add2.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrOfAnd) {
@ -382,7 +445,10 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add2.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
@ -430,7 +496,10 @@ TEST(DeadnessAnalysisTest, AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add3.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Ternary) {
@ -454,7 +523,10 @@ TEST(DeadnessAnalysisTest, Ternary) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Recv) {
@ -469,7 +541,10 @@ TEST(DeadnessAnalysisTest, Recv) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, HostRecv) {
@ -484,7 +559,10 @@ TEST(DeadnessAnalysisTest, HostRecv) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Loop) {
@ -505,8 +583,17 @@ TEST(DeadnessAnalysisTest, Loop) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
bool has_inputs_with_mismatching_deadness;
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -544,7 +631,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -634,7 +724,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -693,7 +786,10 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
{
@ -792,7 +888,10 @@ TEST(DeadnessAnalysisTest, ControlInputs) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, ControlTrigger) {
@ -819,7 +918,10 @@ TEST(DeadnessAnalysisTest, ControlTrigger) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
@ -840,7 +942,10 @@ TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
@ -857,7 +962,10 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *logical_and.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
@ -959,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false");
}
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
Scope root = Scope::NewRootScope().ExitOnError();
Output condition_ref_var =
ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var);
Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false);
Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true);
FixupSourceAndSinkEdges(root.graph());
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0");
EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0");
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,206 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
namespace jit {
using xla::StatusOr;
void DeviceSet::Insert(DeviceId device_id) {
int word_index = device_id.id() / kWordSize;
int bit_index = device_id.id() % kWordSize;
if (word_index >= storage_.size()) {
storage_.resize(word_index + 1, 0);
}
storage_[word_index] |= (1ull << bit_index);
}
void DeviceSet::UnionWith(const DeviceSet& other) {
if (other.storage_.size() > storage_.size()) {
storage_.resize(other.storage_.size(), 0);
}
for (int i = 0; i < other.storage_.size(); i++) {
storage_[i] |= other.storage_[i];
}
}
bool DeviceSet::IsEmpty() const {
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
}
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
TF_RET_CHECK(!name.empty());
auto it = name_to_id_.find(name);
if (it != name_to_id_.end()) {
return it->second;
}
int new_id = names_.size();
names_.push_back(string(name));
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
DeviceType* device_type = id_to_device_type_.back().get();
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
name_to_id_.emplace(string(name), DeviceId(new_id));
const XlaOpRegistry::DeviceRegistration* compilation_device;
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
&compilation_device)) {
compilation_device = nullptr;
}
id_to_compilation_device_.push_back(compilation_device);
return DeviceId(new_id);
}
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
std::vector<string> names;
device_set.ForEach([&](DeviceId device_id) {
names.push_back(string(GetNameFor(device_id)));
return false;
});
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
}
} // namespace jit
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
bool failure_to_pick_is_error) {
#define FAILED_TO_PICK_DEVICE(failing_status) \
do { \
if (failure_to_pick_is_error) { \
return failing_status; \
} else { \
return {absl::nullopt}; \
} \
} while (false)
absl::optional<jit::DeviceId> maybe_gpu_device;
absl::optional<jit::DeviceId> maybe_cpu_device;
absl::optional<jit::DeviceId> maybe_unknown_device;
bool multiple_cpu_devices = false;
bool multiple_gpu_devices = false;
bool multiple_unknown_devices = false;
devices.ForEach([&](jit::DeviceId device) {
if (device_info_cache.IsGpu(device)) {
if (maybe_gpu_device) {
multiple_gpu_devices = true;
return false;
}
maybe_gpu_device = device;
} else if (device_info_cache.IsCpu(device)) {
if (maybe_cpu_device) {
multiple_cpu_devices = true;
return false;
}
maybe_cpu_device = device;
} else {
if (maybe_unknown_device) {
multiple_unknown_devices = true;
return false;
}
maybe_unknown_device = device;
}
return true;
});
if (multiple_cpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
}
if (multiple_gpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
}
if (multiple_unknown_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
}
if (maybe_unknown_device && maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and GPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_gpu_device)));
}
if (!allow_mixing_unknown_and_cpu) {
if (maybe_unknown_device && maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and CPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_cpu_device)));
}
}
if (maybe_gpu_device) {
return {*maybe_gpu_device};
} else if (maybe_unknown_device) {
return {*maybe_unknown_device};
} else if (maybe_cpu_device) {
return {*maybe_cpu_device};
}
FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
#undef FAILED_TO_PICK_DEVICE
}
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/true));
return *device_id;
}
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
return PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
/*failure_to_pick_is_error=*/false);
}
} // namespace tensorflow

View File

@ -0,0 +1,211 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#include <functional>
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace jit {
// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when
// auto-clustering.
class DeviceId {
public:
DeviceId(DeviceId&&) = default;
DeviceId(const DeviceId&) = default;
DeviceId& operator=(const DeviceId&) = default;
bool operator==(const DeviceId& other) const { return id() == other.id(); }
bool operator!=(const DeviceId& other) const { return !(*this == other); }
private:
int id_;
explicit DeviceId(int id) : id_(id) {}
int id() const { return id_; }
friend class DeviceInfoCache;
friend class DeviceSet;
};
// A set of DeviceIds, represented as a bitmap.
class DeviceSet {
public:
void Insert(DeviceId device_id);
void UnionWith(const DeviceSet& other);
bool IsEmpty() const;
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
// return false.
//
// TODO(sanjoy): Change this to take a typed std::function if that's
// performance neutral.
template <typename FnTy>
void ForEach(FnTy func) const {
// This is really a poor man's iterator, we should consider writing a proper
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
while (word != 0) {
uint64 only_lowest_bit_set = word & -word;
// The number of trailing zeros in a non-zero word is the index of the
// least significant 1.
int bit_index = ctz_uint64(word);
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
word ^= only_lowest_bit_set;
}
}
}
private:
static int ctz_uint64(uint64 x) {
DCHECK_NE(x, 0);
#ifdef __GNUC__
return __builtin_ctzl(x);
#else
int result = 0u;
while ((x & 1u) == 0u) {
x >>= 1;
++result;
}
return result;
#endif
}
absl::InlinedVector<uint64, 1> storage_;
const int kWordSize = 64;
};
// Caches some miscellaneous information about TF devices. Thread compatible.
class DeviceInfoCache {
public:
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
absl::string_view GetNameFor(DeviceId device) const {
return names_[device.id()];
}
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
return id_to_compilation_device_[device.id()];
}
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
absl::string_view name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
return GetCompilationDevice(device_id);
}
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
return *id_to_device_type_[device.id()];
}
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
absl::string_view device_name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
return std::cref(*id_to_device_type_[device_id.id()]);
}
string DebugString(const DeviceSet& device_set) const;
private:
absl::flat_hash_map<string, DeviceId> name_to_id_;
// These fields are populated for a device in GetIdFor, *before* we give out a
// DeviceId.
std::vector<const XlaOpRegistry::DeviceRegistration*>
id_to_compilation_device_;
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
std::vector<string> names_;
std::vector<bool> is_cpu_;
std::vector<bool> is_gpu_;
};
} // namespace jit
// Returns the DeviceType corresponding to 'device'.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `devices`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
// executable.
//
// Returns a non-OK Status if no unambiguous choice of device exists.
//
// We choose the device using the following rules:
//
// - It is an error for `device_names` to contain more than one device of the
// same type.
// - GPU is preferred over CPU.
// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
// preferred over CPU.
// - XLA devices count as "unrecognized devices".
//
// This set of rules above implicitly assume that XLA:GPU can compile all
// operations in the cluster that XLA:CPU can compile, and if
// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile
// all operations in the cluster that XLA:CPU can compile.
//
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
// the following things:
//
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
// that will run on unknown devices (because the unknown XLA backend may not
// support every operation supported by CPU).
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
// that contains nodes placed on both the CPU and on unknown devices. In this
// case it is the responsibility of the optimization pass that injected the
// CPU nodes into the cluster to ensure that these nodes can be compiled by
// the unknown XLA backend.
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
// This is like `PickDeviceForXla` except that it returns nullopt (instead of a
// non-OK Status) if no unambiguous choice of device exists.
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_

View File

@ -0,0 +1,132 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> device_names,
string* result) {
jit::DeviceInfoCache cache;
jit::DeviceSet device_set;
for (absl::string_view name : device_names) {
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
device_set.Insert(device_id);
}
TF_ASSIGN_OR_RETURN(
jit::DeviceId result_id,
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
*result = string(cache.GetNameFor(result_id));
return Status::OK();
}
void CheckPickDeviceResult(absl::string_view expected_result,
bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
string result;
TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result))
<< "inputs = [" << absl::StrJoin(inputs, ", ")
<< "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu
<< ", expected_result=" << expected_result;
EXPECT_EQ(result, expected_result);
}
void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {
string result;
EXPECT_FALSE(
PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok());
}
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0";
const char* kYPU0 = "/job:localhost/replica:0/task:0/device:YPU:0";
const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1";
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1";
TEST(PickDeviceForXla, UniqueDevice) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0});
}
TEST(PickDeviceForXla, DeviceOrder) {
CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0});
CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0});
CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0});
}
TEST(PickDeviceForXla, MultipleUnknownDevices) {
CheckPickDeviceHasError(false, {kXPU0, kYPU0});
}
TEST(PickDeviceForXla, GpuAndUnknown) {
CheckPickDeviceHasError(false, {kGPU0, kXPU1});
}
TEST(PickDeviceForXla, UnknownAndCpu) {
CheckPickDeviceHasError(false, {kXPU0, kCPU1});
}
TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
CheckPickDeviceHasError(true, {kCPU0, kCPU1});
CheckPickDeviceHasError(false, {kCPU0, kCPU1});
CheckPickDeviceHasError(false, {kGPU0, kGPU1});
CheckPickDeviceHasError(false, {kXPU0, kXPU1});
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
}
void SimpleRoundTripTestForDeviceSet(int num_devices) {
jit::DeviceSet device_set;
jit::DeviceInfoCache device_info_cache;
std::vector<string> expected_devices, actual_devices;
for (int i = 0; i < num_devices; i++) {
string device_name =
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
device_info_cache.GetIdFor(device_name));
device_set.Insert(device_id);
expected_devices.push_back(device_name);
}
device_set.ForEach([&](jit::DeviceId device_id) {
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
return true;
});
EXPECT_EQ(expected_devices, actual_devices);
}
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
SimpleRoundTripTestForDeviceSet(8);
}
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
SimpleRoundTripTestForDeviceSet(800);
}
} // namespace
} // namespace tensorflow

View File

@ -25,12 +25,13 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
@ -50,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
@ -108,14 +110,14 @@ void MarkGuaranteedConstants(
for (const auto& src_arg : src_arg_pairs) {
srcs.push_back(src_arg.first);
}
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsGuaranteedConst(*n,
guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
ReverseDFSFrom(
graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
for (auto& src_arg : src_arg_pairs) {
if (guaranteed_const_nodes.count(src_arg.first) != 0) {
@ -307,6 +309,13 @@ class Encapsulator {
const std::unordered_map<const Node*, Node*>& node_images,
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
// Records the src of the given edge as a control result of the graph.
// Used during graph to function conversion to tie control results to
// the function signature.
Status RecordControlResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images);
// Creates a _Retval node for the src node of edge, and add it to results_,
// if none exists yet. If a new _Retval node is created, also adds the edge
// within the subgraph from the src to the _Retval node.
@ -484,6 +493,11 @@ class Encapsulator {
// Map from source tensor in the input graph to result #.
std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
// Set of node names that are the source of a control output of the
// subgraph. We store strings here so that we can tolerate nodes being
// removed from the graph.
absl::flat_hash_set<string> control_output_nodes_;
// The outside_compilation clusters in this subgraph.
std::unordered_map<string, OutsideCompilationSubgraph>
outside_compilation_subgraphs_;
@ -801,6 +815,15 @@ Status Encapsulator::Subgraph::RecordArg(
return Status::OK();
}
Status Encapsulator::Subgraph::RecordControlResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images) {
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
control_output_nodes_.insert(src_image->name());
return Status::OK();
}
Status Encapsulator::Subgraph::RecordResult(
const Edge* edge,
const std::unordered_map<const Node*, Node*>& node_images) {
@ -1117,17 +1140,22 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
function_def_name_ = name;
FunctionDef fdef;
auto lookup = [this](const Node* node) -> absl::optional<string> {
if (control_output_nodes_.contains(node->name())) {
return absl::make_optional(node->name());
}
return absl::nullopt;
};
// Verify that the graph has well-formed control flow structure.
std::vector<ControlFlowInfo> dummy;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
if (VLOG_IS_ON(1)) {
VLOG(2) << "Build function def " << name;
dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name),
*graph_, library);
dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name),
fdef);
DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
library);
DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
}
const FunctionDef* original_fdef = library->Find(name);
@ -1190,11 +1218,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
if (VLOG_IS_ON(1)) {
VLOG(2) << "Replace function def " << name;
dump_graph::DumpGraphToFile(
absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
library);
dump_graph::DumpFunctionDefToFile(
absl::StrCat("replace_encapsulate_fdef_", name), fdef);
DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
*graph_, library);
DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
fdef);
}
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
@ -1479,9 +1506,10 @@ Status Encapsulator::CopySubgraphEdges(
src_subgraph.RecordOutsideCompilationInputOrControl(
dst_outside_compilation_id, edge);
} else {
// Ignore control edges leaving the subgraph. We will lift them onto the
// enclosing call operators in BuildOutputGraph().
if (!edge->IsControlEdge()) {
if (edge->IsControlEdge()) {
TF_RETURN_IF_ERROR(
src_subgraph.RecordControlResult(edge, node_images));
} else {
TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
}
}
@ -1556,7 +1584,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
if (VLOG_IS_ON(1)) {
// Dump subgraphs.
for (auto& entry : subgraphs_) {
dump_graph::DumpGraphToFile(
DumpGraphToFile(
absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
*entry.second.GetGraph(), library);
}
@ -2320,16 +2348,15 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
return errors::Internal("Failed to find function ", node->type_string(),
" in function library.");
}
FunctionBody* fbody = nullptr;
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
[library](const string& op, const OpDef** sig) {
return library->LookUpOpDef(op, sig);
},
&fbody));
TF_RETURN_IF_ERROR(
InlineFunctionBody(*library, pruned_graph->get(), node, fbody));
delete fbody;
FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
InlineFunctionBodyOptions inline_opts;
inline_opts.override_device = false;
TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
fbody.get(), inline_opts));
}
return Status::OK();
@ -2394,8 +2421,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
&node_images, library));
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference",
*pruned_graph, library);
DumpGraphToFile("pruned_graph_for_shape_inference", *pruned_graph, library);
}
for (auto& subgraph_entry : subgraphs_) {
@ -2471,8 +2497,6 @@ Status EncapsulateSubgraphsInFunctions(
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library) {
Status s;
Encapsulator encapsulator(std::move(group_attribute),
std::move(outside_compilation_attribute),
&graph_in);
@ -2526,19 +2550,49 @@ Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
options.flib_def);
DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
options.flib_def);
}
std::unique_ptr<Graph> graph_out;
FunctionLibraryDefinition* const library = options.flib_def;
// Constant folding below might need to run part of the function to compute
// constants. Create an FunctionLibraryRuntime with a single CPU device
// that can run the part of the function.
// NOTE: If this turns out to be slow, we can cache the FLRs keyed by
// `options`.
SessionOptions session_options;
auto* device_count = session_options.config.mutable_device_count();
device_count->insert({"CPU", 1});
std::vector<std::unique_ptr<Device>> devices;
DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
if (!cpu_factory) {
return errors::NotFound(
"CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
}
TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
session_options, "/job:localhost/replica:0/task:0", &devices));
if (devices.empty()) {
return errors::NotFound(
"Failed to create a CPU device for EncapsulateSubgraphsPass");
}
std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
new ProcessFunctionLibraryRuntime(device_mgr.get(),
options.session_options->env,
TF_GRAPH_DEF_VERSION, library, opts));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
if (flr == nullptr) {
return errors::Internal(
"Failed to create and retrieve function library runtime to run "
"constant folding");
}
auto rewrite_subgraph =
[flr](const std::vector<OutputTensor>& arg_source_tensors,
@ -2637,8 +2691,8 @@ Status EncapsulateSubgraphsPass::Run(
"EncapsulateSubgraphsPass failed");
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
options.flib_def);
DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
options.flib_def);
}
*options.graph = std::move(graph_out);

View File

@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
XlaClusterInfo{func, func_name_attrs, xla_computation_node,
std::map<string, int>{}});
}
bool modified;
s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters,
graph_out.get(), flr, lib_def.get());
graph_out.get(), flr, lib_def.get(), &modified);
if (!s.ok()) return s;
GraphDef graphdef_out;
@ -1105,8 +1106,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"F"}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
@ -1985,7 +1988,10 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
@ -2110,7 +2116,10 @@ TEST(EncapsulateSubgraphsTest,
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@ -2258,8 +2267,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{}},
absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
@ -2271,8 +2281,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{}}},
absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});

View File

@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include <algorithm>
#include <iterator>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/shape_inference.h"
@ -24,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using stream_executor::port::StatusOr;
namespace tensorflow {
@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
return Status::OK();
}
StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name) {
auto cluster_deps = absl::make_unique<
absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
for (const Edge* e : g->edges()) {
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
if (dst_deps_it == cluster_deps->end()) {
cluster_deps->insert(std::make_pair(
*dst_outside_compilation,
absl::flat_hash_set<string>({*src_outside_compilation})));
} else {
dst_deps_it->second.insert(*src_outside_compilation);
}
}
}
auto cluster_deps_ordered =
absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
std::vector<string> ordered_deps(it->second.begin(), it->second.end());
std::sort(ordered_deps.begin(), ordered_deps.end());
cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
}
return std::move(cluster_deps_ordered);
}
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges

View File

@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -89,6 +91,15 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core;
};
// Finds dependencies between outside compilation clusters, including both data
// dependencies and control dependencies. cluster_deps maps the name name of an
// outside compilation cluster to a set of names of outside compilation clusters
// that it depends on.
stream_executor::port::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
@ -30,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
@ -372,8 +372,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
Status EncapsulateXlaComputationsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateXlaComputations(): "
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
**options.graph, options.flib_def);
<< DumpGraphToFile("encapsulate_xla_computations_before",
**options.graph, options.flib_def);
const char* additional_help =
IsCpuGpuCompile(options.graph->get())
@ -383,14 +383,14 @@ Status EncapsulateXlaComputationsPass::Run(
TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def),
additional_help);
VLOG(1) << "EncapsulateXlaComputations() half-way: "
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
**options.graph, options.flib_def);
<< DumpGraphToFile("encapsulate_xla_computations_halfway",
**options.graph, options.flib_def);
TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()),
additional_help);
VLOG(1) << "EncapsulateXlaComputations() finished: "
<< dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
**options.graph, options.flib_def);
<< DumpGraphToFile("encapsulate_xla_computations_after",
**options.graph, options.flib_def);
return Status::OK();
}

View File

@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
@ -287,15 +289,20 @@ absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
return results;
}
string host_compute_node_name(const string& original_oc_name) {
return absl::StrCat("outside_compilation_", original_oc_name,
"_host_compute");
}
// Builds XlaHostCompute NodeDef from the outside compilation call node.
xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
const Node* call_node, const std::map<string, int>& host_compute_core) {
const Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
string original_oc_name;
TF_RETURN_IF_ERROR(GetNodeAttr(
call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
NodeDefBuilder host_compute_builder(
absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"),
"XlaHostCompute");
NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
"XlaHostCompute");
// Copy all attributes.
for (auto attr : call_node->attrs()) {
@ -309,9 +316,25 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
host_compute_builder.Attr("tpu_core", core);
}
// Set input tokens.
host_compute_builder.Attr(kXlaTokenInputNodesAttrName,
std::vector<string>{kXlaTokenArgNodeName});
// Set input tokens and other outside compilation clusters that current
// cluster depends in `kXlaTokenArgNodeName`. This is needed because when
// outside compilation subgraphs are encapsulated and moved to host graph,
// control/data edges between them will only be reflected in host graph.
// From XLA's perspective, two originally dependent clusters are no longer
// connected, which makes them look like they can be scheduled for execution
// in arbitrary order even though in fact they must be executed in order
// according to their host-side graph dependency. This can cause deadlock.
// Therefore, we hint XLA what the correct ordering of these clusters should
// be to avoid deadlocks.
std::vector<string> xla_token_input_nodes;
xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
auto cluster_deps_it = cluster_deps.find(original_oc_name);
if (cluster_deps_it != cluster_deps.end()) {
for (auto dep : cluster_deps_it->second) {
xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
}
}
host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
// Populate inputs.
std::vector<DataType> input_dtypes;
@ -370,8 +393,9 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
// Replace outside compilation function call node with XlaHostCompute node.
// If the function call node has no input/output edges, we will just remove it
// and not create a XlaHostCompute node.
Status ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core) {
xla::StatusOr<Node*> ReplaceOrRemoveOutsideCompilationCallNode(
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
// If the function call node has no input/output edges, just remove it.
bool has_edge = false;
for (auto e : call_node->in_edges()) {
@ -389,17 +413,18 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
if (!has_edge) {
VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString();
g->RemoveNode(call_node);
return Status::OK();
return nullptr;
}
// Build XlaHostCompute NodeDef.
TF_ASSIGN_OR_RETURN(NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core));
TF_ASSIGN_OR_RETURN(
NodeDef node_def,
BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
TF_ASSIGN_OR_RETURN(Node * host_compute_node,
ReplaceNode(g, call_node, node_def));
VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
return Status::OK();
return host_compute_node;
}
// Resets "device_ordinal" attr to placeholder value for related nodes
@ -493,14 +518,9 @@ Status ConstructHostGraph(
device_ordinal_attr.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(host_func), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
*fld->Find(host_func), AttrSlice(&attrs), fld, &host_fbody));
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
// reachable from sink node so all nodes will be copied.
@ -581,10 +601,9 @@ Status ConstructHostGraph(
&host_graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_",
xla_cluster_name),
host_graph, fld);
DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
xla_cluster_name),
host_graph, fld);
}
FunctionDef host_graph_fdef;
@ -605,7 +624,8 @@ Status ConstructHostGraph(
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
FunctionLibraryDefinition* fld,
const string& host_graph_func_name,
Node* xla_computation_node) {
Node* xla_computation_node,
Node* pivot_node) {
// Temporarily use "0" as "device_ordinal". It will be rewritten with the
// correct value in a later pass. We cannot just use placeholder value here
// because FunctionDef instantiation does not allow placeholder value for
@ -614,14 +634,9 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
device_ordinal_attr.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(host_graph_func_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(host_graph_func_name),
AttrSlice(&attrs), fld, &fbody));
Graph* host_graph = fbody->graph;
// We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
@ -631,7 +646,11 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
// Copy all nodes.
std::map<const Node*, Node*> node_map;
node_map[host_graph->source_node()] = main_graph->source_node();
if (pivot_node) {
node_map[host_graph->source_node()] = pivot_node;
} else {
node_map[host_graph->source_node()] = main_graph->source_node();
}
node_map[host_graph->sink_node()] = main_graph->sink_node();
Status s = Status::OK();
auto copy_node_fn = [&](const Node* n) {
@ -684,21 +703,16 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
Graph* host_graph,
Graph* host_graph, Node* pivot_node,
FunctionLibraryDefinition* fld) {
// Use "0" as "device_ordinal". It does not matter for shape inference.
AttrValue device_ordinal_attr;
device_ordinal_attr.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* fbody = nullptr;
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
*fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, &fbody));
Graph* g = fbody->graph;
// Find SendFromHost node.
@ -733,41 +747,45 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
for (Node* n : nodes) {
g->RemoveNode(n);
}
std::map<const Node*, Node*> node_map;
node_map[host_graph->source_node()] = g->source_node();
Status s;
auto copy_node_fn = [&](const Node* n) {
if (!s.ok()) {
return;
}
if (node_map.find(n) != node_map.end()) {
return;
}
NodeDef copy_def = n->def();
Node* copy = g->AddNode(copy_def, &s);
if (!s.ok()) {
return;
}
for (auto e : n->in_edges()) {
if (node_map.find(e->src()) == node_map.end()) {
s = errors::Internal("Cannot find node image for ",
e->src()->DebugString());
return;
}
g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input());
}
node_map[n] = copy;
Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
// Reverse DFS from send_from_host_main_graph, and stop at start_node.
struct Visit {
Node* n;
bool is_exiting;
};
// TODO(b/77601805): consolidate copy graph functions.
ReverseDFSFrom(*host_graph,
std::vector<const Node*>{send_from_host_main_graph},
/*enter=*/nullptr, copy_node_fn, NodeComparatorID());
if (!s.ok()) {
return s;
std::vector<Visit> stack{{send_from_host_main_graph, false}};
std::map<Node*, Node*> node_map;
node_map[host_graph->source_node()] = g->source_node();
while (!stack.empty()) {
Visit& curr = stack.back();
if (curr.is_exiting) {
if (node_map.find(curr.n) == node_map.end()) {
Node* copy = g->CopyNode(curr.n);
if (curr.n != start_node) {
for (const Edge* e : curr.n->in_edges()) {
auto node_iter = node_map.find(e->src());
if (node_iter == node_map.end()) {
return errors::Internal("Cannot find node image for ",
e->src()->DebugString());
}
g->AddEdge(node_iter->second, e->src_output(), copy,
e->dst_input());
}
}
node_map[curr.n] = copy;
}
stack.pop_back();
} else {
curr.is_exiting = true;
if (curr.n != start_node) {
for (const Edge* e : curr.n->in_edges()) {
if (node_map.find(e->src()) != node_map.end()) {
continue;
}
stack.push_back({e->src(), false});
}
}
}
}
send_from_host = node_map[send_from_host_main_graph];
@ -789,7 +807,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
std::unordered_set<const Node*>{send_from_host});
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(shape_inference_graph_name, *g, fld);
DumpGraphToFile(shape_inference_graph_name, *g, fld);
}
// Replace original shape inference graph.
@ -831,14 +849,9 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
device_ordinal_attr.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_attr;
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(func_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(func_name),
AttrSlice(&attrs), fld, &fbody));
Graph* g = fbody->graph;
// Find or create the key placeholder node.
@ -962,14 +975,10 @@ Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
const string& while_node_name,
const string& host_transfer_key) {
// Instantiate the loop cond function.
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()),
AttrSlice(&loop_cond_func.attr()),
fld, &fbody));
Graph* g = fbody->graph;
// Find the _Retval node and the loop cond node.
@ -1033,14 +1042,9 @@ Status RewriteHostWhileLoopCond(
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_temp_value;
FunctionBody* cond_fbody = nullptr;
std::unique_ptr<FunctionBody> cond_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(cond_host_func_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&cond_fbody));
std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody);
*fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, &cond_fbody));
Graph* cond_graph = cond_fbody->graph;
Node* key_arg = nullptr;
for (Node* n : cond_graph->nodes()) {
@ -1113,14 +1117,9 @@ Status RewriteHostWhileLoopBody(
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> attrs;
attrs["device_ordinal"] = device_ordinal_temp_value;
FunctionBody* body_fbody = nullptr;
std::unique_ptr<FunctionBody> body_fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*fld->Find(body_host_func_name), AttrSlice(&attrs), fld,
[&](const string& op, const OpDef** sig) {
return fld->LookUpOpDef(op, sig);
},
&body_fbody));
std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody);
*fld->Find(body_host_func_name), AttrSlice(&attrs), fld, &body_fbody));
Graph* body_graph = body_fbody->graph;
Node* key_arg = nullptr;
for (Node* n : body_graph->nodes()) {
@ -1615,12 +1614,17 @@ Status ExtractOutsideCompilationForFunction(
// We cannot early return here, because we might have outside compilation in
// If/While function body.
// Find dependencies between outside compilation clusters.
TF_ASSIGN_OR_RETURN(auto cluster_deps,
OutsideCompilationClusterDependencies(
fbody->graph, outside_compilation_attr_name));
// Preprocess edges between different outside compilations. They will be
// restored in `ConstructHostGraph()`.
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
fbody->graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
DumpGraphToFile(
absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
*fbody->graph, fld);
}
@ -1666,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction(
}
}
}
std::map<string, Node*> host_compute_nodes;
for (Node* n : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n));
TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core));
auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode(
graph_out.get(), n, host_compute_core, *cluster_deps);
TF_RETURN_IF_ERROR(host_compute_node_or.status());
Node* host_compute_node = host_compute_node_or.ValueOrDie();
if (host_compute_node) {
host_compute_nodes[host_compute_node->name()] = host_compute_node;
}
}
// For XlaHostCompute nodes with dependencies, add control edges between them
// so XlaCompiler can handle them in correct order.
for (auto iter : host_compute_nodes) {
Node* host_compute_node = iter.second;
std::vector<string> token_input_node_names;
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
kXlaTokenInputNodesAttrName,
&token_input_node_names));
for (const string& node_name : token_input_node_names) {
if (node_name == kXlaTokenArgNodeName) {
continue;
}
auto iter = host_compute_nodes.find(node_name);
if (iter != host_compute_nodes.end()) {
graph_out->AddControlEdge(iter->second, host_compute_node);
}
}
}
// Handle nodes with associated functions.
@ -1705,7 +1734,7 @@ Status ExtractOutsideCompilationForFunction(
TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef));
}
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
DumpGraphToFile(
absl::StrCat("extract_outside_compilation_for_func_after_", func_name),
*graph_out, fld);
}
@ -1717,18 +1746,21 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified) {
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld);
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}
std::vector<string> shape_inference_graphs;
*modified = false;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;
Node* n = iter.second.node;
auto const& func_name_attrs = iter.second.func_name_attrs;
auto const& host_compute_core = iter.second.host_compute_core;
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
@ -1736,18 +1768,23 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
TF_RETURN_IF_ERROR(
ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
}
*modified |= has_outside_compilation;
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];
TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
g, fld, host_graph_func_name, n, pivot_node));
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
g, pivot_node, fld));
}
}
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile("extract_outside_compilation_after", *g, fld);
DumpGraphToFile("extract_outside_compilation_after", *g, fld);
}
return Status::OK();
}

View File

@ -101,7 +101,8 @@ Status ExtractOutsideCompilation(
const string& xla_cluster_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
bool* modified);
} // namespace tensorflow

View File

@ -300,14 +300,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) {
&has_outside_compilation));
// Get rewritten XLA computation function.
FunctionBody *xla_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("cluster_rewritten"), AttrSlice(), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&xla_fbody));
std::unique_ptr<FunctionBody> xla_fbody_deleter(xla_fbody);
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
@ -343,18 +338,13 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) {
EXPECT_EQ(shape_inference_graphs.size(), 0);
// Check host graph: verify we have key placeholder and sequencer.
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
Graph *host_graph = host_fbody->graph;
Node *key_placeholder = nullptr, *sequencer = nullptr;
for (Node *n : host_graph->nodes()) {
@ -428,18 +418,13 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
&has_outside_compilation));
// Check host graph is empty.
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
Graph *host_graph = host_fbody->graph;
EXPECT_EQ(host_graph->num_nodes(), 2);
}
@ -476,31 +461,21 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
&has_outside_compilation));
// Check rewritten XLA graph: verify that we have no XlaHostCompute.
FunctionBody *xla_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("cluster_rewritten"), AttrSlice(), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&xla_fbody));
std::unique_ptr<FunctionBody> xla_fbody_deleter(xla_fbody);
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
for (Node *n : xla_fbody->graph->nodes()) {
EXPECT_NE(n->type_string(), "XlaHostCompute");
}
// Check host graph: verify we have no placeholder, but we have "const1".
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody));
Graph *host_graph = host_fbody->graph;
int num_key_placeholders = 0;
for (Node *n : host_graph->nodes()) {
@ -600,18 +575,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
// Check host graph.
{
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
AttrSlice(&host_func_attrs), &fld,
&host_fbody));
Graph *host_graph = host_fbody->graph;
auto node_name_index = host_graph->BuildNodeNameIndex();
@ -654,14 +625,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) {
// Check XLA graph.
{
FunctionBody *xla_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("cluster_rewritten"), AttrSlice(), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&xla_fbody));
std::unique_ptr<FunctionBody> xla_fbody_deleter(xla_fbody);
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
Graph *xla_graph = xla_fbody->graph;
auto node_name_index = xla_graph->BuildNodeNameIndex();
@ -759,18 +725,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) {
// Check host graph.
{
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
AttrSlice(&host_func_attrs), &fld,
&host_fbody));
Graph *host_graph = host_fbody->graph;
auto node_name_index = host_graph->BuildNodeNameIndex();
@ -899,18 +861,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
// Check host graph.
{
FunctionBody *host_fbody = nullptr;
std::unique_ptr<FunctionBody> host_fbody;
AttrValue device_ordinal_temp_value;
device_ordinal_temp_value.set_i(0);
protobuf::Map<string, AttrValue> host_func_attrs;
host_func_attrs["device_ordinal"] = device_ordinal_temp_value;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&host_fbody));
std::unique_ptr<FunctionBody> host_fbody_deleter(host_fbody);
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"),
AttrSlice(&host_func_attrs), &fld,
&host_fbody));
Graph *host_graph = host_fbody->graph;
auto node_name_index = host_graph->BuildNodeNameIndex();
@ -918,14 +876,10 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
Node *call_node = node_name_index["oc_call_fn"];
EXPECT_NE(call_node, nullptr);
FunctionBody *call_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&call_fbody));
std::unique_ptr<FunctionBody> call_fbody_deleter(call_fbody);
std::unique_ptr<FunctionBody> call_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("oc_func_call_host_fn"),
AttrSlice(&host_func_attrs), &fld,
&call_fbody));
// Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes.
bool has_recv = false, has_send = false;
@ -942,14 +896,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
// Check XLA graph.
{
FunctionBody *xla_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("cluster_rewritten"), AttrSlice(), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&xla_fbody));
std::unique_ptr<FunctionBody> xla_fbody_deleter(xla_fbody);
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
Graph *xla_graph = xla_fbody->graph;
auto node_name_index = xla_graph->BuildNodeNameIndex();
@ -958,14 +907,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
EXPECT_NE(fn_node, nullptr);
EXPECT_EQ(fn_node->type_string(), "fn_oc");
FunctionBody *call_fbody = nullptr;
TF_CHECK_OK(FunctionDefToBodyHelper(
*fld.Find("fn_oc"), AttrSlice(), &fld,
[&](const string &op, const OpDef **sig) {
return fld.LookUpOpDef(op, sig);
},
&call_fbody));
std::unique_ptr<FunctionBody> call_fbody_deleter(call_fbody);
std::unique_ptr<FunctionBody> call_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("fn_oc"), AttrSlice(), &fld,
&call_fbody));
// Verify we have XlaHostCompute nodes.
bool has_hc = false;
@ -978,4 +922,165 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) {
}
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterDataDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "identity0" (outside compilation cluster "1")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
}
TEST_F(ExtractOutsideCompilationForFunctionTest,
OutsideCompilationClusterControlDependency) {
// Build the XLA computation func.
// "const0"
// "identity0" = "const0" (outside compilation cluster "0")
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
// control depdent on cluster "0")
// "identity2" = "identity1"
FunctionDefLibrary fdl;
{
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(
s.WithOpName("identity1").WithControlDependencies(identity0), const0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
TF_CHECK_OK(s.ToGraph(g.get()));
std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString()
<< std::endl;
auto node_name_image = g->BuildNodeNameIndex();
node_name_image["identity0"]->AddAttr("_oc", "0");
node_name_image["identity1"]->AddAttr("_oc", "1");
PartialTensorShape shape({2});
node_name_image["identity1"]->AddAttr(
kXlaInferredShapesAttrName, std::vector<PartialTensorShape>{shape});
FunctionDef *xla_fdef = fdl.add_function();
TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
}
FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
protobuf::Map<string, tensorflow::AttrValue> attrs;
std::map<string, int> host_compute_core = {{"0", 1}, {"1", 0}};
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
NameAttrList name_attrs;
name_attrs.set_name("cluster");
*name_attrs.mutable_attr() = attrs;
TF_CHECK_OK(ExtractOutsideCompilationTest(
"_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph",
host_compute_core, &fld, &shape_inference_graphs,
&has_outside_compilation));
// Get rewritten XLA computation function.
std::unique_ptr<FunctionBody> xla_fbody;
TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
AttrSlice(), &fld, &xla_fbody));
auto node_name_index = xla_fbody->graph->BuildNodeNameIndex();
// Check XlaHostCompute nodes.
Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
EXPECT_NE(host_compute_0, nullptr);
Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
EXPECT_NE(host_compute_1, nullptr);
// Check XlaHostCompute nodes' "_xla_token_input_nodes" attr.
std::vector<string> token_input_nodes;
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
std::vector<string> expected_token_input_nodes_0({"_xla_token_arg_node"});
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0);
token_input_nodes.clear();
std::vector<string> expected_token_input_nodes_1(
{"_xla_token_arg_node", "outside_compilation_0_host_compute"});
TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()),
"_xla_token_input_nodes", &token_input_nodes));
EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1);
// Check there is a control edge from host_compute_0 to host_compute_1.
bool has_control_edge = false;
for (const Edge *e : host_compute_1->in_edges()) {
if (e->IsControlEdge() && e->src() == host_compute_0) {
has_control_edge = true;
break;
}
}
EXPECT_TRUE(has_control_edge);
}
} // namespace tensorflow

View File

@ -15,6 +15,9 @@ limitations under the License.
#include <mutex> // NOLINT
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -23,30 +26,50 @@ namespace tensorflow {
namespace {
BuildXlaOpsPassFlags* build_ops_flags;
DumpGraphFlags* dump_graph_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
void AppendDumpGraphFlagsInternal(std::vector<Flag>* flag_list) {
std::vector<Flag> new_flags = {
Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix,
"Path prefix to which graphs dumped during debugging should be "
"written."),
};
flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level;
// We need to use the mark_for_compilation_flags directly here instead of
// going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
// latter will try to setup and parse flags, which would bring us back to this
// setter.
if (absl::SimpleAtoi(value, &opt_level)) {
mark_for_compilation_flags->xla_auto_jit_flag
.optimization_level_single_gpu = opt_level;
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
opt_level;
return true;
}
absl::string_view value_sv(value);
if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
!absl::ConsumeSuffix(&value_sv, ")") ||
!absl::SimpleAtoi(value_sv, &opt_level)) {
return false;
}
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
opt_level;
return true;
}
void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
std::vector<Flag> new_flags = {
Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
"If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
"graphs (graphs that have at least one node placed on a GPU and no "
"more than one GPU is in use through the entire graph) and 0 "
"otherwise. Experimental."),
Flag("tf_xla_min_cluster_size",
&mark_for_compilation_flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
@ -65,10 +88,6 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
&mark_for_compilation_flags->tf_xla_clustering_fuel,
"Places an artificial limit on the number of ops marked as "
"eligible for clustering."),
Flag("tf_xla_fusion_only",
&mark_for_compilation_flags->tf_xla_fusion_only,
"enable fusion of element-wise operations only using XLA when "
"global_jit_level is ON*."),
Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
&mark_for_compilation_flags
->tf_xla_disable_deadness_safety_checks_for_debugging,
@ -80,20 +99,19 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
void AllocateAndParseFlags() {
build_ops_flags = new BuildXlaOpsPassFlags;
build_ops_flags->tf_xla_enable_lazy_compilation = true;
dump_graph_flags = new DumpGraphFlags;
dump_graph_flags->tf_dump_graph_prefix = "/tmp/";
build_ops_flags->tf_xla_print_cluster_outputs = false;
mark_for_compilation_flags = new MarkForCompilationPassFlags;
mark_for_compilation_flags->tf_xla_auto_jit = 0;
mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
0;
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
mark_for_compilation_flags->tf_xla_max_cluster_size =
std::numeric_limits<int32>::max();
mark_for_compilation_flags->tf_xla_clustering_debug = false;
mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
mark_for_compilation_flags->tf_xla_clustering_fuel =
std::numeric_limits<int64>::max();
mark_for_compilation_flags->tf_xla_fusion_only = false;
mark_for_compilation_flags
->tf_xla_disable_deadness_safety_checks_for_debugging = false;
@ -103,32 +121,52 @@ void AllocateAndParseFlags() {
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_enable_lazy_compilation",
&build_ops_flags->tf_xla_enable_lazy_compilation, ""),
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
};
flag_list = new std::vector<Flag>(
{Flag("tf_xla_enable_lazy_compilation",
&build_ops_flags->tf_xla_enable_lazy_compilation, ""),
Flag("tf_xla_print_cluster_outputs",
&build_ops_flags->tf_xla_print_cluster_outputs,
"If true then insert Print nodes to print out values produced by "
"XLA clusters."),
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
Flag("tf_introduce_floating_point_jitter_to_tensors",
setter_for_jitter_tensor_names, "",
"The Tensors to add the jitter to. The tensors are named in the "
"TensorId format of <node name>:<output idx>."),
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
});
AppendDumpGraphFlagsInternal(flag_list);
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
}
} // namespace
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags);
return *build_ops_flags;
return SetterForXlaAutoJitFlag(value);
}
DumpGraphFlags* GetDumpGraphFlags() {
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return dump_graph_flags;
return build_ops_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
@ -146,14 +184,14 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
void AppendDumpGraphFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendDumpGraphFlagsInternal(flag_list);
}
} // namespace tensorflow

View File

@ -23,14 +23,30 @@ limitations under the License.
namespace tensorflow {
// Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags {
struct XlaAutoJitFlag {
// Control compilation of operators into XLA computations on CPU and GPU
// devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very
// likely to be improved; 2 = on for everything.
//
// If all non-CPU ops in the graph being optimized are placed on a single GPU
// and there is at least one node placed on that GPU then
// `optimization_level_single_gpu` applies. Otherwise
// `optimization_level_general` applies.
//
// Experimental.
int32 tf_xla_auto_jit;
int32 optimization_level_single_gpu;
int32 optimization_level_general;
};
// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
// is:
// <number>: sets general and single_gpu setting to the provided number.
// single-gpu(<number>): sets the single_gpu setting to the provided number.
bool SetXlaAutoJitFlagFromFlagString(const string& value);
// Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags {
XlaAutoJitFlag xla_auto_jit_flag;
// Minimum number of operators in an XLA compilation. Ignored for operators
// placed on an XLA device or operators explicitly marked for compilation.
@ -49,11 +65,6 @@ struct MarkForCompilationPassFlags {
// eligible for clustering.
int64 tf_xla_clustering_fuel;
// tf_xla_fusion_only is effective only when global_jit_level is set to ON*
// and overrides its behavior. If true, enable fusion of element-wise
// operations only using XLA.
bool tf_xla_fusion_only;
// If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then
// we do not do deadness related safety checks. This is unsound in general,
// but can be used as a debugging aid.
@ -81,12 +92,21 @@ struct BuildXlaOpsPassFlags {
// Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
// Defaults to true.
bool tf_xla_enable_lazy_compilation;
// If true then insert Print nodes to print out values produced by XLA
// clusters. Useful for debugging.
bool tf_xla_print_cluster_outputs;
};
// Flags for the XLA bridge's dump_graph module.
struct DumpGraphFlags {
// Path prefix to which graphs dumped during debugging should be written.
string tf_dump_graph_prefix;
// Flags for the IntroduceFloatingPointJitter pass.
struct IntroduceFloatingPointJitterPassFlags {
// The amount of jitter to introduce. This amount is added to each element in
// the tensors named in `tensor_names.
float jitter_amount;
// The Tensors to add the jitter to. The tensors are named in the TensorId
// format of <node name>:<output idx>.
std::vector<string> tensor_names;
};
// Return a pointer to the DumpGraphFlags struct;
@ -97,10 +117,12 @@ struct DumpGraphFlags {
// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
// always return the same pointer.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
DumpGraphFlags* GetDumpGraphFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
@ -108,8 +130,6 @@ DumpGraphFlags* GetDumpGraphFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
void AppendDumpGraphFlags(std::vector<tensorflow::Flag>* flag_list);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -13,8 +13,23 @@ cc_library(
srcs = ["graphcycles.cc"],
hdrs = ["graphcycles.h"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "ordered_set",
hdrs = ["ordered_set.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)
@ -28,3 +43,14 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "ordered_set_test",
srcs = ["ordered_set_test.cc"],
deps = [
":ordered_set",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -34,14 +34,20 @@ limitations under the License.
#include <algorithm>
#include <unordered_set>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
typedef std::unordered_set<int32> NodeSet;
using NodeSet = absl::flat_hash_set<int32>;
using OrderedNodeSet = OrderedSet<int32>;
template <typename T>
struct VecStruct {
typedef absl::InlinedVector<T, 4> type;
@ -50,13 +56,11 @@ template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
Node() : in(4), out(4) {} // Small hashtables for in/out edges
int32 rank; // rank number assigned by Pearce-Kelly algorithm
bool visited; // Temporary marker used by depth-first-search
void* data; // User-supplied data
NodeSet in; // List of immediate predecessor nodes in graph
NodeSet out; // List of immediate successor nodes in graph
OrderedNodeSet in; // List of immediate predecessor nodes in graph
OrderedNodeSet out; // List of immediate successor nodes in graph
};
} // namespace
@ -93,7 +97,7 @@ bool GraphCycles::CheckInvariants() const {
if (!ranks.insert(nx->rank).second) {
LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
}
for (auto y : nx->out) {
for (int32 y : nx->out.GetSequence()) {
Node* ny = r->nodes_[y];
if (nx->rank >= ny->rank) {
LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
@ -124,14 +128,14 @@ int32 GraphCycles::NewNode() {
void GraphCycles::RemoveNode(int32 node) {
Node* x = rep_->nodes_[node];
for (auto y : x->out) {
rep_->nodes_[y]->in.erase(node);
for (int32 y : x->out.GetSequence()) {
rep_->nodes_[y]->in.Erase(node);
}
for (auto y : x->in) {
rep_->nodes_[y]->out.erase(node);
for (int32 y : x->in.GetSequence()) {
rep_->nodes_[y]->out.Erase(node);
}
x->in.clear();
x->out.clear();
x->in.Clear();
x->out.Clear();
rep_->free_nodes_.push_back(node);
}
@ -144,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) {
}
bool GraphCycles::HasEdge(int32 x, int32 y) const {
return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
return rep_->nodes_[x]->out.Contains(y);
}
void GraphCycles::RemoveEdge(int32 x, int32 y) {
rep_->nodes_[x]->out.erase(y);
rep_->nodes_[y]->in.erase(x);
rep_->nodes_[x]->out.Erase(y);
rep_->nodes_[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
@ -165,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes_[x];
if (!nx->out.insert(y).second) {
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}
Node* ny = r->nodes_[y];
ny->in.insert(x);
ny->in.Insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
@ -182,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) {
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (!ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.erase(y);
ny->in.erase(x);
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf_);
@ -209,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
nn->visited = true;
r->deltaf_.push_back(n);
for (auto w : nn->out) {
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes_[w];
if (nw->rank == upper_bound) {
return false; // Cycle
@ -235,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
nn->visited = true;
r->deltab_.push_back(n);
for (auto w : nn->in) {
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes_[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack_.push_back(w);
@ -321,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
return path_len;
}
for (auto w : r->nodes_[n]->out) {
for (auto w : r->nodes_[n]->out.GetSequence()) {
if (seen.insert(w).second) {
r->stack_.push_back(w);
}
@ -375,31 +379,94 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
}
Node* nb = rep_->nodes_[b];
std::unordered_set<int32> out = std::move(nb->out);
std::unordered_set<int32> in = std::move(nb->in);
for (auto y : out) {
rep_->nodes_[y]->in.erase(b);
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32 y : out.GetSequence()) {
rep_->nodes_[y]->in.Erase(b);
}
for (auto y : in) {
rep_->nodes_[y]->out.erase(b);
for (int32 y : in.GetSequence()) {
rep_->nodes_[y]->out.Erase(b);
}
rep_->free_nodes_.push_back(b);
for (auto y : out) {
rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size());
for (int32 y : out.GetSequence()) {
InsertEdge(a, y);
}
for (auto y : in) {
rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size());
for (int32 y : in.GetSequence()) {
InsertEdge(y, a);
}
return true;
}
std::unordered_set<int32> GraphCycles::Successors(int32 node) {
return rep_->nodes_[node]->out;
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out.GetSequence();
}
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
return rep_->nodes_[node]->in;
absl::Span<const int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in.GetSequence();
}
std::vector<int32> GraphCycles::SuccessorsCopy(int32 node) const {
absl::Span<const int32> successors = Successors(node);
return std::vector<int32>(successors.begin(), successors.end());
}
std::vector<int32> GraphCycles::PredecessorsCopy(int32 node) const {
absl::Span<const int32> predecessors = Predecessors(node);
return std::vector<int32>(predecessors.begin(), predecessors.end());
}
namespace {
void SortInPostOrder(absl::Span<Node* const> nodes,
std::vector<int32>* to_sort) {
absl::c_sort(*to_sort, [&](int32 a, int32 b) {
DCHECK(a == b || nodes[a]->rank != nodes[b]->rank);
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32> GraphCycles::AllNodesInPostOrder() const {
absl::flat_hash_set<int32> free_nodes_set;
absl::c_copy(rep_->free_nodes_,
std::inserter(free_nodes_set, free_nodes_set.begin()));
std::vector<int32> all_nodes;
all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size());
for (int64 i = 0, e = rep_->nodes_.size(); i < e; i++) {
if (!free_nodes_set.contains(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes_, &all_nodes);
return all_nodes;
}
string GraphCycles::DebugString() const {
absl::flat_hash_set<int32> free_nodes_set;
for (int32 free_node : rep_->free_nodes_) {
free_nodes_set.insert(free_node);
}
string result = "digraph {\n";
for (int i = 0; i < rep_->nodes_.size(); i++) {
if (free_nodes_set.contains(i)) {
continue;
}
for (int32 succ : rep_->nodes_[i]->out.GetSequence()) {
absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n");
}
}
absl::StrAppend(&result, "}\n");
return result;
}
} // namespace tensorflow

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
#include <vector>
// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
@ -38,8 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
#include <unordered_set>
#include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -117,8 +118,26 @@ class GraphCycles {
// Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node);
std::unordered_set<int32> Predecessors(int32 node);
// Warning: Do not use these if iterating over the span and modifying the
// GraphCycles at the same time. Instead use SuccessorsCopy/PredecessorsCopy.
absl::Span<const int32> Successors(int32 node) const;
absl::Span<const int32> Predecessors(int32 node) const;
// Return a copy of the sucessors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32> SuccessorsCopy(int32 node) const;
// Return a copy of the predecessors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32> PredecessorsCopy(int32 node) const;
// Returns all nodes in post order.
//
// If there is a path from X to Y then X appears after Y in the
// returned vector.
std::vector<int32> AllNodesInPostOrder() const;
// Returns the graph in graphviz format.
string DebugString() const;
// ----------------------------------------------------
struct Rep;

View File

@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}
// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
DCHECK(it != value_to_index_.end());
// Since we don't want to move values around in `value_sequence_` we swap
// the value in the last position and with value to be deleted and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}
void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}
void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}
bool Contains(T value) const { return value_to_index_.contains(value); }
size_t Size() const { return value_sequence_.size(); }
absl::Span<T const> GetSequence() const { return value_sequence_; }
private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;
// Maps values to their indices in `value_sequence_`.
absl::flat_hash_map<T, int> value_to_index_;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_

View File

@ -0,0 +1,117 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graphcycles/ordered_set.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
TEST(OrderedSetTest, Insert) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
EXPECT_FALSE(ordered_set.Insert(100));
EXPECT_EQ(ordered_set.Size(), 3);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_TRUE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
EXPECT_FALSE(ordered_set.Contains(40));
std::array<int, 3> expected_sequence = {90, 100, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, Erase) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Erase(100);
EXPECT_EQ(ordered_set.Size(), 2);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_TRUE(ordered_set.Contains(80));
std::array<int, 2> expected_sequence_0 = {90, 80};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0);
ordered_set.Erase(80);
EXPECT_EQ(ordered_set.Size(), 1);
EXPECT_TRUE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 1> expected_sequence_1 = {90};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1);
ordered_set.Erase(90);
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence_2 = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2);
}
TEST(OrderedSetTest, Clear) {
OrderedSet<int> ordered_set;
EXPECT_TRUE(ordered_set.Insert(90));
EXPECT_TRUE(ordered_set.Insert(100));
EXPECT_TRUE(ordered_set.Insert(80));
ordered_set.Clear();
EXPECT_EQ(ordered_set.Size(), 0);
EXPECT_FALSE(ordered_set.Contains(90));
EXPECT_FALSE(ordered_set.Contains(100));
EXPECT_FALSE(ordered_set.Contains(80));
std::array<int, 0> expected_sequence = {};
EXPECT_EQ(ordered_set.GetSequence(), expected_sequence);
}
TEST(OrderedSetTest, LargeInsertions) {
const int kSize = 50 * 9000;
OrderedSet<int> ordered_set;
for (int i = 0; i < kSize; i++) {
EXPECT_TRUE(ordered_set.Insert(i + 500));
}
for (int i = 0; i < kSize; i++) {
EXPECT_EQ(ordered_set.GetSequence()[i], i + 500);
}
}
} // namespace
} // namespace tensorflow

View File

@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
@ -375,15 +375,15 @@ Status IncreaseDynamismForAutoJitPass::Run(
const GraphOptimizationPassOptions& options) {
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
}
bool changed;
TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed));
if (changed && flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def);
DumpGraphToFile("increase_dynamism_for_auto_jit_pass", **options.graph,
options.flib_def);
}
return Status::OK();

Some files were not shown because too many files have changed in this diff Show More