Resolve merge conflicts and remove unnecessary __restrict keywords
This commit is contained in:
commit
de37020401
8
.bazelrc
8
.bazelrc
@ -30,6 +30,10 @@ build:monolithic --define framework_shared_object=false
|
||||
# opts in to modular op registration support by default.
|
||||
build --define framework_shared_object=true
|
||||
|
||||
# Flags for open source build, always set to be true.
|
||||
build --define open_source_build=true
|
||||
test --define open_source_build=true
|
||||
|
||||
# Please note that MKL on MacOS or windows is still not supported.
|
||||
# If you would like to use a local MKL instead of downloading, please set the
|
||||
# environment variable "TF_MKL_ROOT" every time before build.
|
||||
@ -108,6 +112,10 @@ build --spawn_strategy=standalone
|
||||
build --strategy=Genrule=standalone
|
||||
build -c opt
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build --cxxopt=-std=c++14
|
||||
build --host_cxxopt=-std=c++14
|
||||
|
||||
# Make Bazel print out all options from rc files.
|
||||
build --announce_rc
|
||||
|
||||
|
@ -1,13 +1,14 @@
|
||||
# Where component owners are known, add them here.
|
||||
|
||||
/tensorflow/c/eager @jaingurav @alextp
|
||||
/tensorflow/c/eager @jaingaurav @alextp
|
||||
/tensorflow/core/common_runtime/eager @jaingaurav @alextp
|
||||
/tenosrflow/core/debug @caisq
|
||||
/tensorflow/core/nccl/ @azaks2 @chsigg
|
||||
/tensorflow/core/platform/windows/ @mrry
|
||||
/tensorflow/core/platform/s3 @yongtang
|
||||
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
||||
/tensorflow/python/debug @caisq
|
||||
/tensorflow/python/eager @jaingurav @alextp
|
||||
/tensorflow/python/eager @jaingaurav @alextp
|
||||
/tensorflow/python/tools/api/generator/ @annarev
|
||||
/tensorflow/tensorboard/ @jart
|
||||
/tensorflow/tools/docs/ @markdaoust
|
||||
@ -15,6 +16,7 @@
|
||||
# contrib
|
||||
|
||||
# NEED OWNER: /tensorflow/contrib/all_reduce
|
||||
/tensorflow/contrib/autograph/ @mdanatg @kkimdev
|
||||
/tensorflow/contrib/batching/ @alextp @chrisolston
|
||||
/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon
|
||||
/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva
|
||||
@ -26,11 +28,10 @@
|
||||
/tensorflow/contrib/data/ @mrry
|
||||
/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn
|
||||
/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi
|
||||
/tensorflow/contrib/eager @jaingurav @alextp
|
||||
/tensorflow/contrib/eager @jaingaurav @alextp
|
||||
/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo
|
||||
/tensorflow/contrib/ffmpeg/ @fredbertsch
|
||||
/tensorflow/contrib/framework/ @ebrevdo
|
||||
/tensorflow/contrib/gan/ @joel-shor
|
||||
/tensorflow/contrib/graph_editor/ @purpledog
|
||||
# NEED OWNER: /tensorflow/contrib/grid_rnn/
|
||||
/tensorflow/contrib/hadoop @yongtang
|
||||
|
@ -60,7 +60,13 @@ If you are experiencing or witnessing conflict, we ask you to use the following
|
||||
|
||||
## Reporting Violations
|
||||
|
||||
Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Edd Wilder-James (ewj@google.com) and Sarah Novotny (sarahnovotny@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report.
|
||||
Violations of the Code of Conduct can be reported to TensorFlow’s Project
|
||||
Stewards, Edd Wilder-James (ewj@google.com) and Thea Lamkin
|
||||
(thealamkin@google.com). The Project Steward will determine whether the Code of
|
||||
Conduct was violated, and will issue an appropriate sanction, possibly including
|
||||
a written warning or expulsion from the project, project sponsored spaces, or
|
||||
project forums. We ask that you make a good-faith effort to resolve your
|
||||
conflict via the conflict resolution policy before submitting a report.
|
||||
|
||||
Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report.
|
||||
|
||||
|
@ -29,7 +29,8 @@ Follow either of the two links above to access the appropriate CLA and instructi
|
||||
### Contributing code
|
||||
|
||||
If you have improvements to TensorFlow, send us your pull requests! For those
|
||||
just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/).
|
||||
just getting started, Github has a
|
||||
[how to](https://help.github.com/articles/using-pull-requests/).
|
||||
|
||||
TensorFlow team members will be assigned to review your pull requests. Once the
|
||||
pull requests are approved and pass continuous integration checks, a TensorFlow
|
||||
|
98
README.md
98
README.md
@ -2,61 +2,58 @@
|
||||
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
|
||||
</div>
|
||||
|
||||
-----------------
|
||||
|
||||
|
||||
| **`Documentation`** |
|
||||
|-----------------|
|
||||
| [](https://www.tensorflow.org/api_docs/) |
|
||||
|
||||
**TensorFlow** is an open source software library for numerical computation
|
||||
using data flow graphs. The graph nodes represent mathematical operations, while
|
||||
the graph edges represent the multidimensional data arrays (tensors) that flow
|
||||
between them. This flexible architecture enables you to deploy computation to
|
||||
one or more CPUs or GPUs in a desktop, server, or mobile device without
|
||||
rewriting code. TensorFlow also includes
|
||||
[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization
|
||||
toolkit.
|
||||
[TensorFlow](https://www.tensorflow.org/) is an end-to-end open source platform
|
||||
for machine learning. It has a comprehensive, flexible ecosystem of
|
||||
[tools](https://www.tensorflow.org/resources/tools),
|
||||
[libraries](https://www.tensorflow.org/resources/libraries-extensions), and
|
||||
[community](https://www.tensorflow.org/community) resources that lets
|
||||
researchers push the state-of-the-art in ML and developers easily build and
|
||||
deploy ML powered applications.
|
||||
|
||||
TensorFlow was originally developed by researchers and engineers
|
||||
working on the Google Brain team within Google's Machine Intelligence Research
|
||||
organization for the purposes of conducting machine learning and deep neural
|
||||
networks research. The system is general enough to be applicable in a wide
|
||||
variety of other domains, as well.
|
||||
TensorFlow was originally developed by researchers and engineers working on the
|
||||
Google Brain team within Google's Machine Intelligence Research organization for
|
||||
the purposes of conducting machine learning and deep neural networks research.
|
||||
The system is general enough to be applicable in a wide variety of other
|
||||
domains, as well.
|
||||
|
||||
TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
|
||||
compatible API's for C++, Go, Java, JavaScript, and Swift.
|
||||
TensorFlow provides stable [Python](https://www.tensorflow.org/api_docs/python)
|
||||
and [C++](https://www.tensorflow.org/api_docs/cc) APIs, as well as
|
||||
non-guaranteed backwards compatible API for
|
||||
[other languages](https://www.tensorflow.org/api_docs).
|
||||
|
||||
Keep up to date with release announcements and security updates by
|
||||
subscribing to
|
||||
Keep up-to-date with release announcements and security updates by subscribing
|
||||
to
|
||||
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
|
||||
See all the [mailing lists](https://www.tensorflow.org/community/forums).
|
||||
|
||||
## Installation
|
||||
## Install
|
||||
|
||||
See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
|
||||
[pip package](https://www.tensorflow.org/install/pip), to
|
||||
[enable GPU support](https://www.tensorflow.org/install/gpu), use a
|
||||
[Docker container](https://www.tensorflow.org/install/docker), and
|
||||
[build from source](https://www.tensorflow.org/install/source).
|
||||
|
||||
To install the current release for CPU-only:
|
||||
|
||||
```
|
||||
pip install tensorflow
|
||||
$ pip install tensorflow
|
||||
```
|
||||
|
||||
Use the GPU package for CUDA-enabled GPU cards:
|
||||
Use the GPU package for
|
||||
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu):
|
||||
|
||||
```
|
||||
pip install tensorflow-gpu
|
||||
$ pip install tensorflow-gpu
|
||||
```
|
||||
|
||||
*See [Installing TensorFlow](https://www.tensorflow.org/install) for detailed
|
||||
instructions, and how to build from source.*
|
||||
|
||||
People who are a little more adventurous can also try our nightly binaries:
|
||||
|
||||
**Nightly pip packages** * We are pleased to announce that TensorFlow now offers
|
||||
nightly pip packages under the
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi.
|
||||
Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean
|
||||
environment to install the nightly TensorFlow build. We support CPU and GPU
|
||||
packages on Linux, Mac, and Windows.
|
||||
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.*
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
||||
@ -74,8 +71,8 @@ $ python
|
||||
'Hello, TensorFlow!'
|
||||
```
|
||||
|
||||
Learn more examples about how to do specific tasks in TensorFlow at the
|
||||
[tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
|
||||
For more examples, see the
|
||||
[TensorFlow tutorials](https://www.tensorflow.org/tutorials/).
|
||||
|
||||
## Contribution guidelines
|
||||
|
||||
@ -116,6 +113,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux AMD ROCm GPU** Nightly | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
|
||||
**Linux AMD ROCm GPU** Stable Release | [](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | [Release](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/)
|
||||
**Linux s390x** Nightly | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
|
||||
**Linux s390x CPU** Stable Release | [](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
|
||||
**Linux ppc64le CPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
|
||||
@ -126,20 +125,23 @@ Build Type
|
||||
**Linux CPU with Intel® MKL-DNN** <br> **Supports Python 2.7, 3.4, 3.5, and 3.6** | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
|
||||
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/)
|
||||
|
||||
## For more information
|
||||
## Resources
|
||||
|
||||
* [TensorFlow Website](https://www.tensorflow.org)
|
||||
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
|
||||
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
|
||||
* [TensorFlow.org](https://www.tensorflow.org)
|
||||
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
|
||||
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
|
||||
* [TensorFlow examples](https://github.com/tensorflow/examples)
|
||||
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [TensorFlow blog](https://blog.tensorflow.org)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
|
||||
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
|
||||
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
|
||||
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
|
||||
* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
|
||||
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
|
||||
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
|
||||
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
|
||||
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
|
||||
|
||||
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
|
||||
Learn more about the
|
||||
[TensorFlow community](https://www.tensorflow.org/community) and how to
|
||||
[contribute](https://www.tensorflow.org/community/contribute).
|
||||
|
||||
## License
|
||||
|
||||
|
@ -43,6 +43,11 @@
|
||||
* Transitive dependencies on :pooling_ops were removed. Some users may need to
|
||||
add explicit dependencies on :pooling_ops if they reference the operators
|
||||
from that library.
|
||||
* tf.keras.optimizers default learning rate changes:
|
||||
* Adadelta: 1.000 to 0.001
|
||||
* Adagrad: 0.01 to 0.001
|
||||
* Adamax: 0.002 to 0.001
|
||||
* NAdam: 0.002 to 0.001
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
@ -746,7 +751,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
|
||||
and [programmers guide page](http://tensorflow.org/versions/r1.9/programmers_guide/keras).
|
||||
* Update `tf.keras` to the Keras 2.1.6 API.
|
||||
* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
|
||||
* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
|
||||
* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/r1/boosted_trees).
|
||||
* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite)
|
||||
for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md)
|
||||
has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
|
||||
|
31
WORKSPACE
31
WORKSPACE
@ -7,7 +7,7 @@ http_archive(
|
||||
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
||||
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
|
||||
urls = [
|
||||
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
|
||||
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
|
||||
],
|
||||
)
|
||||
@ -49,9 +49,14 @@ remote_config_workspace()
|
||||
# Apple and Swift rules.
|
||||
http_archive(
|
||||
name = "build_bazel_rules_apple",
|
||||
sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e",
|
||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"],
|
||||
sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1",
|
||||
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_apple/releases
|
||||
http_archive(
|
||||
name = "build_bazel_rules_swift",
|
||||
sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311",
|
||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_swift/releases
|
||||
http_archive(
|
||||
name = "build_bazel_apple_support",
|
||||
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
|
||||
@ -62,11 +67,6 @@ http_archive(
|
||||
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
|
||||
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/bazel-skylib/releases
|
||||
http_archive(
|
||||
name = "build_bazel_rules_swift",
|
||||
sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e",
|
||||
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"],
|
||||
) # https://github.com/bazelbuild/rules_swift/releases
|
||||
http_archive(
|
||||
name = "com_github_apple_swift_swift_protobuf",
|
||||
type = "zip",
|
||||
@ -104,8 +104,7 @@ http_archive(
|
||||
build_file = "//:models.BUILD",
|
||||
sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
|
||||
urls = [
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
|
||||
"http://download.tensorflow.org/models/inception_v1.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
|
||||
],
|
||||
)
|
||||
|
||||
@ -114,8 +113,7 @@ http_archive(
|
||||
build_file = "//:models.BUILD",
|
||||
sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
|
||||
urls = [
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
|
||||
"http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
|
||||
],
|
||||
)
|
||||
|
||||
@ -124,8 +122,7 @@ http_archive(
|
||||
build_file = "//:models.BUILD",
|
||||
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
|
||||
urls = [
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
|
||||
"http://download.tensorflow.org/models/mobile_multibox_v1a.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
|
||||
],
|
||||
)
|
||||
|
||||
@ -134,8 +131,7 @@ http_archive(
|
||||
build_file = "//:models.BUILD",
|
||||
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
|
||||
urls = [
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
|
||||
"http://download.tensorflow.org/models/stylize_v1.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
|
||||
],
|
||||
)
|
||||
|
||||
@ -144,7 +140,6 @@ http_archive(
|
||||
build_file = "//:models.BUILD",
|
||||
sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
|
||||
urls = [
|
||||
"http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||
"http://download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
|
||||
],
|
||||
)
|
||||
|
77
configure.py
77
configure.py
@ -1145,78 +1145,6 @@ def set_trisycl_include_dir(environ_cp):
|
||||
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
|
||||
|
||||
|
||||
def set_mpi_home(environ_cp):
|
||||
"""Set MPI_HOME."""
|
||||
|
||||
default_mpi_home = which('mpirun') or which('mpiexec') or ''
|
||||
default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
|
||||
|
||||
def valid_mpi_path(mpi_home):
|
||||
exists = (
|
||||
os.path.exists(os.path.join(mpi_home, 'include')) and
|
||||
(os.path.exists(os.path.join(mpi_home, 'lib')) or
|
||||
os.path.exists(os.path.join(mpi_home, 'lib64')) or
|
||||
os.path.exists(os.path.join(mpi_home, 'lib32'))))
|
||||
if not exists:
|
||||
print(
|
||||
'Invalid path to the MPI Toolkit. %s or %s or %s or %s cannot be found'
|
||||
% (os.path.join(mpi_home, 'include'),
|
||||
os.path.exists(os.path.join(mpi_home, 'lib')),
|
||||
os.path.exists(os.path.join(mpi_home, 'lib64')),
|
||||
os.path.exists(os.path.join(mpi_home, 'lib32'))))
|
||||
return exists
|
||||
|
||||
_ = prompt_loop_or_load_from_env(
|
||||
environ_cp,
|
||||
var_name='MPI_HOME',
|
||||
var_default=default_mpi_home,
|
||||
ask_for_var='Please specify the MPI toolkit folder.',
|
||||
check_success=valid_mpi_path,
|
||||
error_msg='',
|
||||
suppress_default_error=True)
|
||||
|
||||
|
||||
def set_other_mpi_vars(environ_cp):
|
||||
"""Set other MPI related variables."""
|
||||
# Link the MPI header files
|
||||
mpi_home = environ_cp.get('MPI_HOME')
|
||||
symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h')
|
||||
|
||||
# Determine if we use OpenMPI or MVAPICH, these require different header files
|
||||
# to be included here to make bazel dependency checker happy
|
||||
if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')):
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'include/mpi_portable_platform.h'),
|
||||
'third_party/mpi/mpi_portable_platform.h')
|
||||
# TODO(gunan): avoid editing files in configure
|
||||
sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = False',
|
||||
'MPI_LIB_IS_OPENMPI = True')
|
||||
else:
|
||||
# MVAPICH / MPICH
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h')
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h')
|
||||
# TODO(gunan): avoid editing files in configure
|
||||
sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = True',
|
||||
'MPI_LIB_IS_OPENMPI = False')
|
||||
|
||||
if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')):
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so')
|
||||
elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')):
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so')
|
||||
elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')):
|
||||
symlink_force(
|
||||
os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so')
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' %
|
||||
(mpi_home, mpi_home, mpi_home))
|
||||
|
||||
|
||||
def system_specific_test_config(env):
|
||||
"""Add default build and test flags required for TF tests to bazelrc."""
|
||||
write_to_bazelrc('test --flaky_test_attempts=3')
|
||||
@ -1549,11 +1477,6 @@ def main():
|
||||
raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. '
|
||||
'At most 1 GPU platform can be configured.')
|
||||
|
||||
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
|
||||
if environ_cp.get('TF_NEED_MPI') == '1':
|
||||
set_mpi_home(environ_cp)
|
||||
set_other_mpi_vars(environ_cp)
|
||||
|
||||
set_cc_opt_flags(environ_cp)
|
||||
set_system_libs_flag(environ_cp)
|
||||
if is_windows():
|
||||
|
@ -7,7 +7,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"//tensorflow/core/platform:default/build_config.bzl",
|
||||
"tf_additional_binary_deps",
|
||||
)
|
||||
load(
|
||||
@ -356,6 +356,15 @@ config_setting(
|
||||
},
|
||||
)
|
||||
|
||||
# Flag to indicate open source build, .bazelrc always has it set to be true
|
||||
config_setting(
|
||||
name = "oss",
|
||||
define_values = {
|
||||
"open_source_build": "true",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "using_cuda_clang_with_dynamic_build",
|
||||
define_values = {
|
||||
@ -364,11 +373,20 @@ config_setting(
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "build_oss_using_cuda_clang",
|
||||
define_values = {
|
||||
"using_cuda_clang": "true",
|
||||
"open_source_build": "true",
|
||||
},
|
||||
)
|
||||
|
||||
# Setting to use when loading kernels dynamically
|
||||
config_setting(
|
||||
name = "dynamic_loaded_kernels",
|
||||
define_values = {
|
||||
"dynamic_loaded_kernels": "true",
|
||||
"framework_shared_object": "true",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
@ -389,16 +407,18 @@ config_setting(
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "using_rocm_hipcc",
|
||||
name = "build_oss_using_cuda_nvcc",
|
||||
define_values = {
|
||||
"using_rocm_hipcc": "true",
|
||||
"using_cuda_nvcc": "true",
|
||||
"open_source_build": "true",
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "with_mpi_support",
|
||||
values = {"define": "with_mpi_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
name = "using_rocm_hipcc",
|
||||
define_values = {
|
||||
"using_rocm_hipcc": "true",
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
@ -444,6 +464,7 @@ config_setting(
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
@ -607,6 +628,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/c:version_script.lds",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
],
|
||||
)
|
||||
|
||||
@ -750,8 +772,8 @@ genrule(
|
||||
mkdir $@
|
||||
for f in $(SRCS); do
|
||||
d="$${f%/*}"
|
||||
d="$${d#bazel-out*genfiles/}"
|
||||
d="$${d#*external/eigen_archive/}"
|
||||
d="$${d#bazel-out/*/genfiles/}"
|
||||
d="$${d#bazel-out/*/bin/}"
|
||||
|
||||
if [[ $${d} == *local_config_* ]]; then
|
||||
continue
|
||||
@ -763,6 +785,9 @@ genrule(
|
||||
if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
d="$${d#*external/farmhash_archive/src}"
|
||||
d="$${d#*external/$${extname}/}"
|
||||
fi
|
||||
|
||||
mkdir -p "$@/$${d}"
|
||||
|
@ -27,11 +27,27 @@ import sys as _sys
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python.tools import module_util as _module_util
|
||||
from tensorflow.python.platform import tf_logging as _logging
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# WRAPPER_PLACEHOLDER
|
||||
|
||||
if "dev" in __version__: # pylint: disable=undefined-variable
|
||||
_logging.warning("""
|
||||
|
||||
TensorFlow's `tf-nightly` package will soon be updated to TensorFlow 2.0.
|
||||
|
||||
Please upgrade your code to TensorFlow 2.0:
|
||||
* https://www.tensorflow.org/beta/guide/migration_guide
|
||||
|
||||
Or install the latest stable TensorFlow 1.X release:
|
||||
* `pip install -U "tensorflow==1.*"`
|
||||
|
||||
Otherwise your code may be broken by the change.
|
||||
|
||||
""")
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||
# We're using bitwise, but there's nothing special about that.
|
||||
|
@ -73,7 +73,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_platform",
|
||||
"//tensorflow/core/platform:platform",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
],
|
||||
@ -264,10 +264,10 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_platform",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -355,6 +355,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
@ -467,7 +468,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:no_op_op_lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||
"//tensorflow/core:spectral_ops_op_lib",
|
||||
@ -503,6 +503,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -579,7 +580,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = ["noasan"],
|
||||
tags = ["no_cuda_on_cpu_tap"],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
# the shared library must be able to use core:framework.
|
||||
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||
@ -588,10 +589,11 @@ tf_cuda_cc_test(
|
||||
":kernels",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
|
||||
desc->colocation_constraints.insert(location);
|
||||
}
|
||||
} else {
|
||||
desc->node_builder.Attr(attr_name, attr_value);
|
||||
desc->node_builder.Attr(attr_name, std::move(attr_value));
|
||||
}
|
||||
|
||||
status->status = Status::OK();
|
||||
@ -1045,7 +1045,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
|
||||
std::vector<string>(desc->colocation_constraints.begin(),
|
||||
desc->colocation_constraints.end()));
|
||||
}
|
||||
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
|
||||
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
|
||||
/*consume=*/true);
|
||||
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
// Run shape inference function for newly added node.
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -596,7 +598,10 @@ struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
|
||||
TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
|
||||
TF_Status* status) {
|
||||
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
if (!status->status.ok()) {
|
||||
TF_DeleteCheckpointReader(reader);
|
||||
return nullptr;
|
||||
}
|
||||
const auto& m = reader->GetVariableToDataTypeMap();
|
||||
for (auto it = m.begin(); it != m.end(); ++it)
|
||||
reader->variable_list.push_back(it->first);
|
||||
@ -995,3 +1000,170 @@ TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(
|
||||
<< handle->DebugString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
|
||||
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
|
||||
result->num_items = num_items;
|
||||
result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
|
||||
return result;
|
||||
}
|
||||
|
||||
void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
|
||||
const int64_t* dims, int num_dims) {
|
||||
DCHECK(index >= 0 && index < shape_list->num_items);
|
||||
TF_ShapeAndType& shape = shape_list->items[index];
|
||||
DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
|
||||
DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
|
||||
shape.num_dims = num_dims;
|
||||
shape.dims = new int64_t[num_dims];
|
||||
memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
|
||||
}
|
||||
|
||||
void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
|
||||
int index) {
|
||||
DCHECK(index >= 0 && index < shape_list->num_items);
|
||||
TF_ShapeAndType& shape = shape_list->items[index];
|
||||
DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
|
||||
shape.num_dims = -1;
|
||||
shape.dims = nullptr;
|
||||
}
|
||||
|
||||
void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
|
||||
TF_DataType dtype) {
|
||||
DCHECK(index >= 0 && index < shape_list->num_items);
|
||||
TF_ShapeAndType& shape_and_type = shape_list->items[index];
|
||||
shape_and_type.dtype = dtype;
|
||||
}
|
||||
|
||||
void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
|
||||
if (shape_list == nullptr) return;
|
||||
for (size_t i = 0; i < shape_list->num_items; ++i) {
|
||||
delete[] shape_list->items[i].dims;
|
||||
}
|
||||
delete[] shape_list->items;
|
||||
delete shape_list;
|
||||
}
|
||||
|
||||
void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
|
||||
int num_items) {
|
||||
if (shape_list_array == nullptr) return;
|
||||
for (int i = 0; i < num_items; ++i) {
|
||||
TF_DeleteShapeAndTypeList(shape_list_array[i]);
|
||||
}
|
||||
delete[] shape_list_array;
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
} // namespace tensorflow
|
||||
|
||||
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
TF_Tensor** input_tensors,
|
||||
TF_ShapeAndTypeList* input_tensors_as_shapes,
|
||||
TF_ShapeAndTypeList** input_resource_shapes_and_types,
|
||||
TF_ShapeAndTypeList** output_shapes,
|
||||
TF_ShapeAndTypeList*** output_resource_shapes_and_types,
|
||||
TF_Status* status) {
|
||||
using tensorflow::NodeDef;
|
||||
using tensorflow::OpRegistrationData;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::shape_inference::DimensionHandle;
|
||||
using tensorflow::shape_inference::InferenceContext;
|
||||
using tensorflow::shape_inference::ShapeAndType;
|
||||
using tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation.Name());
|
||||
node_def.set_op(tfe_op->operation.Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
// Initialize a input_tensor vector with `nullptr` values.
|
||||
std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
|
||||
// A vector to keep track of newly created `tf::Tensor` objects.
|
||||
std::vector<Tensor> all_input_tensors;
|
||||
// Update the vector with information from `input_tensors` if provided.
|
||||
if (input_tensors != nullptr) {
|
||||
// Note that we take the address of the elements in `all_input_tensors`
|
||||
// below. Allocate enough space so that no reallocation happens, which will
|
||||
// make the pointers invalid.
|
||||
all_input_tensors.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (input_tensors[i] == nullptr) continue;
|
||||
all_input_tensors.emplace_back();
|
||||
Tensor& input_tensor = all_input_tensors.back();
|
||||
status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
|
||||
if (!status->status.ok()) return;
|
||||
input_tensors_vector[i] = &input_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// Create an inference context with dummy values, which will be updated later.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
|
||||
std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
|
||||
{},
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
|
||||
|
||||
// Set input_shapes.
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
std::vector<DimensionHandle> dims;
|
||||
const TF_ShapeAndType& input_shape = input_shapes->items[i];
|
||||
if (input_shape.num_dims == InferenceContext::kUnknownRank) {
|
||||
c.SetInput(i, c.UnknownShape());
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < input_shape.num_dims; ++j) {
|
||||
dims.push_back(c.MakeDim(input_shape.dims[j]));
|
||||
}
|
||||
c.SetInput(i, c.MakeShape(dims));
|
||||
}
|
||||
|
||||
// TODO(bgogul): Handle input_tensors_as_shapes.
|
||||
// TODO(bgogul): Handle input_resource_shapes_and_types.
|
||||
|
||||
status->status = c.construction_status();
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
if (op_reg_data->shape_inference_fn == nullptr) {
|
||||
status->status =
|
||||
InvalidArgument("No shape inference function exists for op '",
|
||||
node_def.op(), "', did you forget to define it?");
|
||||
return;
|
||||
}
|
||||
|
||||
status->status = c.Run(op_reg_data->shape_inference_fn);
|
||||
if (!status->status.ok()) return;
|
||||
|
||||
// Set output_shapes.
|
||||
TF_ShapeAndTypeList* output_shapes_result =
|
||||
TF_NewShapeAndTypeList(c.num_outputs());
|
||||
for (int i = 0; i < c.num_outputs(); ++i) {
|
||||
ShapeHandle shape_handle = c.output(i);
|
||||
TF_ShapeAndType& shape = output_shapes_result->items[i];
|
||||
shape.num_dims = c.Rank(shape_handle);
|
||||
if (shape.num_dims == InferenceContext::kUnknownRank) {
|
||||
shape.dims = nullptr;
|
||||
continue;
|
||||
}
|
||||
shape.dims = new int64_t[shape.num_dims];
|
||||
for (size_t j = 0; j < shape.num_dims; ++j) {
|
||||
shape.dims[j] = c.Value(c.Dim(shape_handle, j));
|
||||
}
|
||||
}
|
||||
if (output_shapes != nullptr) *output_shapes = output_shapes_result;
|
||||
|
||||
// TODO(bgogul): Set output_resource_shapes_and_types.
|
||||
}
|
||||
|
||||
void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable) {
|
||||
opts->opts.validate_colocation_constraints = enable;
|
||||
}
|
||||
|
@ -343,6 +343,65 @@ TF_CAPI_EXPORT extern TFE_TensorHandle*
|
||||
TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx,
|
||||
unsigned int idx);
|
||||
|
||||
// Information about the shape of a Tensor and its type.
|
||||
struct TF_ShapeAndType {
|
||||
// Number of dimensions. -1 indicates unknown rank.
|
||||
int num_dims;
|
||||
// Array of dimensions. -1 indicates unknown dim.
|
||||
int64_t* dims;
|
||||
// The data type. May be 0 to denote unknown type.
|
||||
TF_DataType dtype;
|
||||
};
|
||||
|
||||
typedef struct TF_ShapeAndType TF_ShapeAndType;
|
||||
|
||||
// A list of TF_ShapeAndType elements..
|
||||
struct TF_ShapeAndTypeList {
|
||||
int num_items;
|
||||
TF_ShapeAndType* items;
|
||||
};
|
||||
typedef struct TF_ShapeAndTypeList TF_ShapeAndTypeList;
|
||||
|
||||
// API for manipulating TF_ShapeAndTypeList objects.
|
||||
//
|
||||
TF_CAPI_EXPORT extern TF_ShapeAndTypeList* TF_NewShapeAndTypeList(
|
||||
int num_shapes);
|
||||
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetShape(
|
||||
TF_ShapeAndTypeList* shape_list, int index, const int64_t* dims,
|
||||
int num_dims);
|
||||
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetUnknownShape(
|
||||
TF_ShapeAndTypeList* shape_list, int index);
|
||||
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetDtype(
|
||||
TF_ShapeAndTypeList* shape_list, int index, TF_DataType dtype);
|
||||
TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList(
|
||||
TF_ShapeAndTypeList* shape_list);
|
||||
TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray(
|
||||
TF_ShapeAndTypeList** shape_list_array, int num_items);
|
||||
|
||||
// Infer shapes for the given `op`. The arguments mimic the arguments of the
|
||||
// `shape_inference::InferenceContext` constructor. Note the following:
|
||||
// - The inputs of the `op` are not used for shape inference. So, it is
|
||||
// OK to not have the inputs properly set in `op`. See `input_tensors`
|
||||
// if you want shape inference to consider the input tensors of the
|
||||
// op for shape inference.
|
||||
// - The types need not be set in `input_shapes` as it is not used.
|
||||
// - The number of `input_tensors` should be the same as the number of items
|
||||
// in `input_shapes`.
|
||||
//
|
||||
// The results are returned in `output_shapes` and
|
||||
// `output_resource_shapes_and_types`. The caller is responsible for freeing the
|
||||
// memory in these buffers by calling `TF_DeleteShapeAndTypeList`.
|
||||
TF_CAPI_EXPORT extern void TFE_InferShapes(
|
||||
TFE_Op* op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors,
|
||||
TF_ShapeAndTypeList* input_tensor_as_shapes,
|
||||
TF_ShapeAndTypeList** input_resource_shapes_and_types,
|
||||
TF_ShapeAndTypeList** output_shapes,
|
||||
TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
|
||||
TF_ImportGraphDefOptions* opts, unsigned char enable);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
@ -431,5 +433,155 @@ TEST_F(AddEagerOpToGraphTest,
|
||||
TFE_DeleteTensorHandle(matrix);
|
||||
}
|
||||
|
||||
class ShapeInferenceTest : public ::testing::Test {
|
||||
protected:
|
||||
ShapeInferenceTest()
|
||||
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
|
||||
tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
}
|
||||
|
||||
~ShapeInferenceTest() override {
|
||||
TFE_DeleteContextOptions(tfe_context_options_);
|
||||
TFE_DeleteContext(tfe_context_);
|
||||
TF_DeleteStatus(status_);
|
||||
}
|
||||
|
||||
// Checks the expected result of shape inference for the given `op`.
|
||||
void CheckOutputShapes(
|
||||
TFE_Op* op,
|
||||
const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
|
||||
const std::vector<TF_Tensor*>& input_tensors,
|
||||
const absl::optional<std::vector<int64_t>>& expected_shape) {
|
||||
// Create input_shapes.
|
||||
TF_ShapeAndTypeList* input_shapes =
|
||||
TF_NewShapeAndTypeList(input_shapes_vec.size());
|
||||
for (size_t i = 0; i < input_shapes_vec.size(); ++i) {
|
||||
const auto& input_shape = input_shapes_vec[i];
|
||||
if (input_shape.has_value()) {
|
||||
TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(),
|
||||
input_shape->size());
|
||||
} else {
|
||||
TF_ShapeAndTypeListSetUnknownShape(input_shapes, i);
|
||||
}
|
||||
}
|
||||
TF_ShapeAndTypeList* output_shapes;
|
||||
TFE_InferShapes(op, input_shapes,
|
||||
input_tensors.empty()
|
||||
? nullptr
|
||||
: const_cast<TF_Tensor**>(input_tensors.data()),
|
||||
/*input_tensors_as_shapes*/ nullptr,
|
||||
/*input_resource_shapes_and_types*/ nullptr, &output_shapes,
|
||||
/*output_resource_shapes_and_types*/ nullptr, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_EQ(output_shapes->num_items, 1);
|
||||
|
||||
int num_dims = output_shapes->items[0].num_dims;
|
||||
int64_t* dims = output_shapes->items[0].dims;
|
||||
|
||||
if (!expected_shape.has_value()) {
|
||||
EXPECT_EQ(num_dims, -1);
|
||||
EXPECT_EQ(dims, nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
EXPECT_EQ(num_dims, expected_shape->size());
|
||||
for (size_t i = 0; i < num_dims; ++i) {
|
||||
EXPECT_EQ(dims[i], (*expected_shape)[i]);
|
||||
}
|
||||
TF_DeleteShapeAndTypeList(input_shapes);
|
||||
TF_DeleteShapeAndTypeList(output_shapes);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<int64_t>> make_shape(
|
||||
std::vector<int64_t>&& dims) const {
|
||||
return absl::make_optional(dims);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<int64_t>> unknown_shape() const {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
static constexpr int64_t kUnknownDim =
|
||||
shape_inference::InferenceContext::kUnknownDim;
|
||||
TF_Status* status_;
|
||||
TFE_ContextOptions* tfe_context_options_;
|
||||
TFE_Context* tfe_context_;
|
||||
};
|
||||
|
||||
TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
|
||||
TFE_Op* matmul_op;
|
||||
matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
|
||||
// Infer shape when everything is known.
|
||||
CheckOutputShapes(matmul_op,
|
||||
/*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
|
||||
/*input_tensors*/ {},
|
||||
/*expected_shape*/ make_shape({3, 4}));
|
||||
|
||||
// Infer shape when second operand has unknown shape.
|
||||
CheckOutputShapes(matmul_op,
|
||||
/*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
|
||||
/*input_tensors*/ {},
|
||||
/*expected_shape*/ make_shape({3, kUnknownDim}));
|
||||
|
||||
// Infer shape when some dimensions are unknown.
|
||||
CheckOutputShapes(
|
||||
matmul_op,
|
||||
/*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
|
||||
/*input_tensors*/ {},
|
||||
/*expected_shape*/ make_shape({kUnknownDim, 4}));
|
||||
|
||||
// Infer shape when everything is unknown.
|
||||
CheckOutputShapes(matmul_op,
|
||||
/*input_shapes*/ {unknown_shape(), unknown_shape()},
|
||||
/*input_tensors*/ {},
|
||||
/*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
|
||||
|
||||
TFE_DeleteOp(matmul_op);
|
||||
// TODO(bgogul): Add some death tests where status is not OK.
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
// Prepare some tensors for shape.
|
||||
TF_Tensor* tensor_1X6 = Int32Tensor({1, 6});
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6});
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
|
||||
TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32);
|
||||
CheckOutputShapes(reshape_op,
|
||||
/* input_shapes*/ {unknown_shape(), unknown_shape()},
|
||||
/* input_tensors*/ {nullptr, tensor_1X6},
|
||||
/*expected_shape*/ make_shape({1, 6}));
|
||||
TFE_DeleteOp(reshape_op);
|
||||
reshape_op = nullptr;
|
||||
|
||||
TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
|
||||
|
||||
float five = 5.0;
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
|
||||
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CheckOutputShapes(fill_op,
|
||||
/* input_shapes*/ {unknown_shape(), unknown_shape()},
|
||||
/* input_tensors*/ {tensor_1X1X6, scalarTensor},
|
||||
/*expected_shape*/ make_shape({1, 1, 6}));
|
||||
TFE_DeleteOp(fill_op);
|
||||
fill_op = nullptr;
|
||||
|
||||
TFE_DeleteTensorHandle(scalar);
|
||||
TF_DeleteTensor(scalarTensor);
|
||||
TF_DeleteTensor(tensor_1X1X6);
|
||||
TF_DeleteTensor(tensor_1X6);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -41,6 +41,7 @@ namespace {
|
||||
// node names, so if necessary we add a suffix to make
|
||||
// names unique. If we have an input named "A" and a node in the function
|
||||
// body named "a", they will be renamed to "a" and "a_0".
|
||||
// TODO(b/139886381) Unify this and the one in graph_to_functiondef.cc
|
||||
class NodeNameMapping {
|
||||
public:
|
||||
NodeNameMapping() = default;
|
||||
@ -64,14 +65,14 @@ class NodeNameMapping {
|
||||
string Lookup(const string& name) const;
|
||||
|
||||
private:
|
||||
string UniquifyHelper(const string& name) const;
|
||||
string UniquifyHelper(const string& name);
|
||||
static string Normalize(string name);
|
||||
|
||||
// The normalized/uniquified names already used as
|
||||
// input names (in signature), output names (in signature), and node names
|
||||
// (in node_def).
|
||||
// This is a superset of values in name_mapping_.
|
||||
std::unordered_set<string> used_names_;
|
||||
std::unordered_map<string, uint64> used_names_;
|
||||
// Mapping from original node name from the graph to the normalized
|
||||
// and uniquified version of it.
|
||||
std::unordered_map<string, string> name_mapping_;
|
||||
@ -102,13 +103,16 @@ string NodeNameMapping::Normalize(string name) {
|
||||
return i == n ? "unknown" : name.substr(i);
|
||||
}
|
||||
|
||||
string NodeNameMapping::UniquifyHelper(const string& name) const {
|
||||
string NodeNameMapping::UniquifyHelper(const string& name) {
|
||||
auto it = used_names_.emplace(name, 0);
|
||||
// If the name hasn't been used yet, use it as-is.
|
||||
if (used_names_.find(name) == used_names_.end()) return name;
|
||||
if (it.second) return name;
|
||||
|
||||
// Add a suffix to name to make it unique.
|
||||
for (int i = 0;; ++i) {
|
||||
const string candidate = strings::StrCat(name, "_", i);
|
||||
if (used_names_.find(candidate) == used_names_.end()) return candidate;
|
||||
while (true) {
|
||||
const string candidate = strings::StrCat(name, "_", it.first->second);
|
||||
it.first->second++;
|
||||
if (used_names_.emplace(candidate, 0).second) return candidate;
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,16 +124,13 @@ string NodeNameMapping::GetInputName(const string& name) {
|
||||
|
||||
string NodeNameMapping::GetOutputName(const string& name) {
|
||||
const string& input_name = UniquifyHelper(Normalize(name));
|
||||
// Record that we used this name, but don't add it to name_mapping_
|
||||
// since this name is not for a node.
|
||||
used_names_.insert(input_name);
|
||||
// Don't add it to name_mapping_ since this name is not for a node.
|
||||
return input_name;
|
||||
}
|
||||
|
||||
string NodeNameMapping::Uniquify(const string& name) {
|
||||
const string uniqued = UniquifyHelper(name);
|
||||
name_mapping_[name] = uniqued;
|
||||
used_names_.insert(uniqued);
|
||||
return uniqued;
|
||||
}
|
||||
|
||||
@ -139,7 +140,7 @@ Status NodeNameMapping::UseOutputName(const string& name) {
|
||||
return InvalidArgument("Cannot have duplicate output names. Name '", name,
|
||||
"' appears more than once in 'output_names' array.");
|
||||
}
|
||||
used_names_.insert(iter, name);
|
||||
used_names_.emplace(name, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -22,15 +22,16 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
@ -233,7 +234,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
// Create C++ Tensor
|
||||
Tensor src(tensorflow::DT_STRING, TensorShape(dims));
|
||||
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
||||
src.flat<string>()(i) = data[i];
|
||||
src.flat<tstring>()(i) = data[i];
|
||||
}
|
||||
TF_Tensor* dst = TF_TensorFromTensor(src, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -243,7 +244,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
|
||||
ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
|
||||
ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
|
||||
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
|
||||
ASSERT_EQ(data[i], output.flat<string>()(i)) << line;
|
||||
ASSERT_EQ(data[i], output.flat<tstring>()(i)) << line;
|
||||
}
|
||||
|
||||
TF_DeleteTensor(dst);
|
||||
@ -556,7 +557,7 @@ TEST(CAPI, Graph) {
|
||||
EXPECT_FALSE(found_add);
|
||||
found_add = true;
|
||||
} else {
|
||||
ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
|
||||
ADD_FAILURE() << "Unexpected NodeDef: " << n.DebugString();
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found_placeholder);
|
||||
@ -581,20 +582,20 @@ TEST(CAPI, Graph) {
|
||||
// Compare with first GraphDef + added NodeDef.
|
||||
NodeDef* added_node = graph_def.add_node();
|
||||
*added_node = node_def;
|
||||
EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2));
|
||||
EXPECT_EQ(graph_def.DebugString(), graph_def2.DebugString());
|
||||
|
||||
// Look up some nodes by name.
|
||||
TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
|
||||
EXPECT_TRUE(neg == neg2);
|
||||
NodeDef node_def2;
|
||||
ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
|
||||
EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));
|
||||
EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
|
||||
|
||||
TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
|
||||
EXPECT_TRUE(feed == feed2);
|
||||
ASSERT_TRUE(GetNodeDef(feed, &node_def));
|
||||
ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
|
||||
EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));
|
||||
EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
|
||||
|
||||
// Test iterating through the nodes of a graph.
|
||||
found_placeholder = false;
|
||||
@ -618,7 +619,7 @@ TEST(CAPI, Graph) {
|
||||
found_neg = true;
|
||||
} else {
|
||||
ASSERT_TRUE(GetNodeDef(oper, &node_def));
|
||||
ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def);
|
||||
ADD_FAILURE() << "Unexpected Node: " << node_def.DebugString();
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found_placeholder);
|
||||
@ -1385,7 +1386,7 @@ TEST(CAPI, SavedModel) {
|
||||
tensorflow::Example example;
|
||||
auto* feature_map = example.mutable_features()->mutable_feature();
|
||||
(*feature_map)["x"].mutable_float_list()->add_value(i);
|
||||
input.flat<string>()(i) = example.SerializeAsString();
|
||||
input.flat<tstring>()(i) = example.SerializeAsString();
|
||||
}
|
||||
|
||||
const tensorflow::string input_op_name(
|
||||
@ -2498,6 +2499,38 @@ TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
|
||||
|
||||
#undef EXPECT_TF_META
|
||||
|
||||
TEST(CAPI, TestTensorAligned) {
|
||||
int64_t dim = 7;
|
||||
size_t tensor_size_bytes = dim * TF_DataTypeSize(TF_FLOAT);
|
||||
TF_Tensor* a = TF_AllocateTensor(
|
||||
/*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1,
|
||||
/*len=*/tensor_size_bytes);
|
||||
float* data = reinterpret_cast<float*>(TF_TensorData(a));
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
if (EIGEN_MAX_ALIGN_BYTES > 0) {
|
||||
EXPECT_TRUE(TF_TensorIsAligned(a));
|
||||
}
|
||||
TF_DeleteTensor(a);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTensorIsNotAligned) {
|
||||
// Test unaligned access via a Slice.
|
||||
Tensor x(DT_FLOAT, TensorShape({30}));
|
||||
x.flat<float>().setConstant(0.0);
|
||||
|
||||
// Take an unaligned slice.
|
||||
Tensor y = x.Slice(1, 13);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* a = TF_TensorFromTensor(y, status);
|
||||
if (EIGEN_MAX_ALIGN_BYTES > 0) {
|
||||
EXPECT_FALSE(TF_TensorIsAligned(a));
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
TF_DeleteTensor(a);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -8,12 +8,12 @@ load(
|
||||
"tfe_xla_copts",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"//tensorflow/core/platform:default/build_config.bzl",
|
||||
"tf_additional_device_tracer_test_flags",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"//tensorflow/core/platform:default/build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
|
||||
@ -156,6 +156,7 @@ tf_cuda_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_test_util",
|
||||
@ -235,9 +236,11 @@ tf_cuda_cc_test(
|
||||
],
|
||||
args =
|
||||
["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_test_util",
|
||||
|
@ -202,9 +202,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
|
||||
}
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
|
||||
tensorflow::uint64 context_id = tensorflow::random::New64();
|
||||
tensorflow::uint64 context_id = tensorflow::EagerContext::NewContextId();
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
|
||||
context_id, ctx->context));
|
||||
|
||||
std::vector<string> remote_workers;
|
||||
grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
|
||||
@ -240,9 +242,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
&remote_eager_workers));
|
||||
|
||||
// Initialize remote eager workers.
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
remote_workers, context_id, keep_alive_secs, server_def,
|
||||
remote_eager_workers.get(), ctx->context->Async(), base_request));
|
||||
// TODO(b/138847548) Create remote eager contexts in async mode by default.
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(),
|
||||
ctx->context->Executor()->Async(), base_request));
|
||||
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
@ -261,15 +265,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
auto remote_mgr =
|
||||
absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true);
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, ctx->context);
|
||||
|
||||
return ctx->context->InitializeRemoteMaster(
|
||||
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
|
||||
std::move(server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs,
|
||||
worker_session->cluster_flr.get(), std::move(remote_mgr));
|
||||
worker_session->cluster_flr.get(), std::move(remote_mgr)));
|
||||
|
||||
// NOTE: We start the server after all other initialization, because the
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
@ -365,12 +375,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
|
||||
options->device_placement_policy = policy;
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status) {
|
||||
status->status = ctx->context->SetAsyncForThread(enable);
|
||||
}
|
||||
|
||||
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
@ -455,18 +459,6 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
ctx->context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
|
||||
status->status = ctx->context->AsyncWait();
|
||||
}
|
||||
|
||||
void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
|
||||
status->status = ctx->context->GetStatus();
|
||||
}
|
||||
|
||||
void TFE_ContextAsyncClearError(TFE_Context* ctx) {
|
||||
ctx->context->ClearAsyncError();
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
@ -571,7 +563,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
status->status = EagerCopyToDevice(
|
||||
handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
|
||||
handle, handle->Context(), handle->Context()->Executor(),
|
||||
handle->Context()->HostCPU(), false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -671,7 +664,7 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
unsigned char* is_list, TF_Status* status) {
|
||||
TF_AttrType ret;
|
||||
TF_AttrType ret = TF_ATTR_INT;
|
||||
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
|
||||
attr_name, &ret, is_list);
|
||||
return ret;
|
||||
@ -683,10 +676,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
TF_AttrType ret;
|
||||
TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
|
||||
if (!status->status.ok()) {
|
||||
return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
|
||||
if (status->status.ok()) {
|
||||
ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
|
||||
} else {
|
||||
ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
|
||||
}
|
||||
ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
|
||||
TFE_DeleteOp(op);
|
||||
return ret;
|
||||
}
|
||||
@ -922,6 +916,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
return nullptr;
|
||||
}
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
ctx->context->Executor(),
|
||||
device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
@ -957,12 +952,10 @@ unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
|
||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(true);
|
||||
ctx->context->SetShouldStoreStepStats(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||
ctx->context->SetShouldStoreGraphs(false);
|
||||
ctx->context->SetShouldStoreStepStats(false);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@ -974,7 +967,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
status->status = ctx->context->Executor()->WaitForAllPendingNodes();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
|
||||
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
|
||||
|
29
tensorflow/c/eager/c_api.h
Executable file → Normal file
29
tensorflow/c/eager/c_api.h
Executable file → Normal file
@ -77,7 +77,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
||||
|
||||
// Sets the default execution mode (sync/async). Note that this can be
|
||||
// overridden per thread using TFE_ContextSetAsyncForThread.
|
||||
// overridden per thread using TFE_ContextSetExecutorForThread.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
|
||||
unsigned char enable);
|
||||
|
||||
@ -89,6 +89,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
|
||||
|
||||
// "Context" under which operations/functions are executed. It encapsulates
|
||||
// things like the available devices, resource manager etc.
|
||||
// TFE_Context must outlive all tensor handles created using it. In other
|
||||
// words, TFE_DeleteContext() must be called after all tensor handles have
|
||||
// been deleted (with TFE_DeleteTensorHandle).
|
||||
//
|
||||
// TODO(ashankar): Merge with TF_Session?
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
@ -115,11 +118,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||
TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
|
||||
TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx);
|
||||
|
||||
// Overrides the execution mode (sync/async) for the current thread.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// A tensorflow.ServerDef specifies remote workers (in addition to the current
|
||||
// workers name). Operations created on this context can then be executed on
|
||||
// any of these remote workers by setting an appropriate device.
|
||||
@ -132,25 +130,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
// Causes the calling thread to block till all ops dispatched in async mode
|
||||
// have been executed. Note that "execution" here refers to kernel execution /
|
||||
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
|
||||
// that lower level device queues (like GPU streams) have been flushed.
|
||||
//
|
||||
// This call may not block for execution of ops enqueued concurrently with this
|
||||
// call.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*,
|
||||
TF_Status* status);
|
||||
|
||||
// When an error happens, any pending operations are discarded and newly issued
|
||||
// ops return an error. This call clears the error state and re-enables
|
||||
// execution of newly issued ops.
|
||||
//
|
||||
// Note that outputs of discarded ops remain in a corrupt state and should not
|
||||
// be used for future calls.
|
||||
// TODO(agarwal): mark the affected handles and raise errors if they are used.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*);
|
||||
|
||||
// A handle to a tensor on a device.
|
||||
//
|
||||
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
|
||||
|
@ -32,9 +32,7 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
op->operation.ConsumeInput(h->handle);
|
||||
}
|
||||
|
||||
TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) {
|
||||
return new TFE_Profiler(ctx);
|
||||
}
|
||||
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
||||
|
||||
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
|
||||
return profiler->profiler->Status().ok();
|
||||
@ -55,23 +53,10 @@ void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
|
||||
};
|
||||
}
|
||||
|
||||
TFE_ProfilerContext* TFE_NewProfilerContext() {
|
||||
return new TFE_ProfilerContext;
|
||||
}
|
||||
|
||||
void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context,
|
||||
TFE_Context* eager_context) {
|
||||
profiler_context->profiler_context.eager_context = eager_context->context;
|
||||
}
|
||||
|
||||
void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) {
|
||||
delete profiler_context;
|
||||
}
|
||||
|
||||
void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) {
|
||||
// Release child thread intentionally. The child thread can be terminate by
|
||||
void TFE_StartProfilerServer(int port) {
|
||||
// Release child thread intentionally. The child thread can be terminated by
|
||||
// terminating the main thread.
|
||||
tensorflow::StartProfilerServer(&context->profiler_context, port).release();
|
||||
tensorflow::StartProfilerServer(port).release();
|
||||
}
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
@ -587,3 +572,30 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
op->operation.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||
return new TFE_Executor(is_async);
|
||||
}
|
||||
|
||||
void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
|
||||
|
||||
bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
|
||||
return executor->executor()->Async();
|
||||
}
|
||||
|
||||
void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
|
||||
TF_Status* status) {
|
||||
status->status = executor->executor()->WaitForAllPendingNodes();
|
||||
}
|
||||
|
||||
void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
executor->executor()->ClearError();
|
||||
}
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
ctx->context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
return new TFE_Executor(ctx->context->Executor());
|
||||
}
|
||||
|
@ -25,8 +25,6 @@ extern "C" {
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
typedef struct TFE_ProfilerContext TFE_ProfilerContext;
|
||||
|
||||
// A profiler which will start profiling when creating the object and will stop
|
||||
// when the object is destroyed. It will profile all operations run under the
|
||||
// given TFE_Context. Multiple instance of it can be created, but at most one
|
||||
@ -34,7 +32,7 @@ typedef struct TFE_ProfilerContext TFE_ProfilerContext;
|
||||
// Thread-safety: TFE_Profiler is thread-safe.
|
||||
typedef struct TFE_Profiler TFE_Profiler;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx);
|
||||
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
|
||||
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
|
||||
|
||||
@ -44,27 +42,14 @@ TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Return a new profiler context object.
|
||||
TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void);
|
||||
|
||||
// Set the eager context in TFE_ProfilerServerOptions
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext(
|
||||
TFE_ProfilerContext* profiler_context, TFE_Context* eager_context);
|
||||
|
||||
// Destroy a profiler context object.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext(
|
||||
TFE_ProfilerContext* profiler_context);
|
||||
|
||||
// Start a profiler grpc server which listens to specified port. It will start
|
||||
// the server on its own thread. It can be shutdown by terminating tensorflow.
|
||||
// It can be used in both Eager mode and graph mode. Creating multiple profiler
|
||||
// server is allowed. The service defined in
|
||||
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable
|
||||
// file following
|
||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context,
|
||||
int port);
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
|
||||
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
|
||||
|
||||
// Enables only graph collection in RunMetadata on the functions executed from
|
||||
// this context.
|
||||
@ -367,6 +352,51 @@ TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager(
|
||||
TFE_Op* op, TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Eager Executor APIs.
|
||||
typedef struct TFE_Executor TFE_Executor;
|
||||
|
||||
// Creates a new eager Executor. Nodes in one executor are guaranteed to be
|
||||
// executed in sequence. Assigning nodes to different executors allows executing
|
||||
// nodes in parallel.
|
||||
TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async);
|
||||
|
||||
// Deletes the eager Executor without waiting for enqueued nodes. Please call
|
||||
// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
|
||||
// make sure all nodes are finished.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*);
|
||||
|
||||
// Returns true if the executor is in async mode.
|
||||
TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*);
|
||||
|
||||
// Causes the calling thread to block till all ops dispatched in this executor
|
||||
// have been executed. Note that "execution" here refers to kernel execution /
|
||||
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
|
||||
// that lower level device queues (like GPU streams) have been flushed.
|
||||
//
|
||||
// This call may not block for execution of ops enqueued concurrently with this
|
||||
// call.
|
||||
TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes(
|
||||
TFE_Executor*, TF_Status* status);
|
||||
|
||||
// When an error happens, any pending operations are discarded and newly issued
|
||||
// ops return an error. This call clears the error state and re-enables
|
||||
// execution of newly issued ops.
|
||||
//
|
||||
// Note that outputs of discarded ops remain in a corrupt state and should not
|
||||
// be used for future calls.
|
||||
// TODO(agarwal): mark the affected handles and raise errors if they are used.
|
||||
TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*);
|
||||
|
||||
// Sets a custom Executor for current thread. All nodes created by this thread
|
||||
// will be added to this Executor. It will override current executor.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*,
|
||||
TFE_Executor*);
|
||||
|
||||
// Returns the Executor for current thread.
|
||||
TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread(
|
||||
TFE_Context*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
@ -43,12 +44,9 @@ void ExecuteWithProfiling(bool async) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext();
|
||||
TFE_ProfilerContextSetEagerContext(profiler_context, ctx);
|
||||
TFE_Profiler* profiler = TFE_NewProfiler(profiler_context);
|
||||
TFE_Profiler* profiler = TFE_NewProfiler();
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_DeleteProfilerContext(profiler_context);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
@ -71,8 +69,10 @@ void ExecuteWithProfiling(bool async) {
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Buffer* profiler_result = TF_NewBuffer();
|
||||
if (async) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
|
||||
TFE_DeleteProfiler(profiler);
|
||||
@ -85,7 +85,10 @@ void ExecuteWithProfiling(bool async) {
|
||||
if (!gpu_device_name.empty()) {
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
|
||||
// device name with "stream:all" is collected by Device Tracer.
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
// ROCm platform does not yet support stream level tracing
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
|
||||
#endif
|
||||
}
|
||||
// "/host:CPU" is collected by TraceMe
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
|
||||
@ -110,27 +113,14 @@ TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
|
||||
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
|
||||
|
||||
TEST(CAPI, MultipleProfilerSession) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext();
|
||||
TFE_ProfilerContextSetEagerContext(profiler_context, ctx);
|
||||
|
||||
TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context);
|
||||
TFE_Profiler* profiler1 = TFE_NewProfiler();
|
||||
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
|
||||
|
||||
TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context);
|
||||
TFE_Profiler* profiler2 = TFE_NewProfiler();
|
||||
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
|
||||
|
||||
TFE_DeleteProfiler(profiler1);
|
||||
TFE_DeleteProfiler(profiler2);
|
||||
TFE_DeleteProfilerContext(profiler_context);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringCounter0) {
|
||||
@ -307,5 +297,205 @@ TEST(CAPI, CancellationManager) {
|
||||
TFE_DeleteCancellationManager(c_mgr);
|
||||
}
|
||||
|
||||
TEST(CAPI, Function_ident_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
TF_OperationDescription* arg_descr =
|
||||
TF_NewOperation(function_graph, "Placeholder", "arg");
|
||||
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_OperationDescription* id_descr =
|
||||
TF_NewOperation(function_graph, "Identity", "id");
|
||||
TF_SetAttrType(id_descr, "T", TF_INT32);
|
||||
TF_AddInput(id_descr, {arg, 0});
|
||||
TF_Operation* id = TF_FinishOperation(id_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_Output input{arg, 0};
|
||||
TF_Output output{id, 0};
|
||||
TF_Function* fn =
|
||||
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
|
||||
&output, nullptr, nullptr, "test", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteGraph(function_graph);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_ContextAddFunction(ctx, fn, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteFunction(fn);
|
||||
|
||||
for (bool async : {false, true, false}) {
|
||||
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
|
||||
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
|
||||
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, h, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
|
||||
std::vector<TFE_TensorHandle*> result;
|
||||
result.push_back(nullptr);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, result.data(), &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(num_retvals, 1);
|
||||
|
||||
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
|
||||
TFE_ContextSetExecutorForThread(ctx, old_executor);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteExecutor(old_executor);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
TEST(CAPI, Function_ident_XLA_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
TF_OperationDescription* arg_descr =
|
||||
TF_NewOperation(function_graph, "Placeholder", "arg");
|
||||
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_OperationDescription* id_descr =
|
||||
TF_NewOperation(function_graph, "Identity", "id");
|
||||
TF_SetAttrType(id_descr, "T", TF_INT32);
|
||||
TF_AddInput(id_descr, {arg, 0});
|
||||
TF_Operation* id = TF_FinishOperation(id_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_Output input{arg, 0};
|
||||
TF_Output output{id, 0};
|
||||
TF_Function* fn =
|
||||
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
|
||||
&output, nullptr, nullptr, "test", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteGraph(function_graph);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_ContextAddFunction(ctx, fn, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteFunction(fn);
|
||||
|
||||
for (bool async : {false, true, false}) {
|
||||
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
|
||||
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
|
||||
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, h, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
|
||||
// Now run it via XLA.
|
||||
TFE_OpSetXLACompilation(op, true);
|
||||
|
||||
std::vector<TFE_TensorHandle*> result;
|
||||
result.push_back(nullptr);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, result.data(), &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(num_retvals, 1);
|
||||
|
||||
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
|
||||
TFE_ContextSetExecutorForThread(ctx, old_executor);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteExecutor(old_executor);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
void Executor_MatMul_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
|
||||
int num_retvals = 2;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_ContextSetExecutorForThread(ctx, old_executor);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteExecutor(old_executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(7, product[0]);
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
|
||||
TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -76,7 +76,14 @@ struct TFE_Context {
|
||||
async, device_mgr, device_mgr_owned, rendezvous,
|
||||
custom_kernel_creator)) {}
|
||||
|
||||
~TFE_Context() { context->Unref(); }
|
||||
~TFE_Context() {
|
||||
// TODO(iga): Add a separate API method to shutdown TFE_Context so that we
|
||||
// don't send RPCs and block in destructor.
|
||||
context->WaitForAndCloseRemoteContexts();
|
||||
// context->RefCountIsOne() should be true here.
|
||||
// TODO(iga): Remove EagerContext refcounting.
|
||||
context->Unref();
|
||||
}
|
||||
|
||||
tensorflow::EagerContext* context;
|
||||
};
|
||||
@ -130,14 +137,8 @@ struct TFE_Op {
|
||||
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
|
||||
};
|
||||
|
||||
struct TFE_ProfilerContext {
|
||||
tensorflow::ProfilerContext profiler_context;
|
||||
};
|
||||
|
||||
struct TFE_Profiler {
|
||||
explicit TFE_Profiler(TFE_ProfilerContext* ctx) {
|
||||
profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context);
|
||||
}
|
||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||
|
||||
std::unique_ptr<tensorflow::ProfilerSession> profiler;
|
||||
};
|
||||
@ -291,4 +292,19 @@ struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
@ -78,7 +79,10 @@ void BM_Execute(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
if (async) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(matmul);
|
||||
@ -89,6 +93,41 @@ void BM_Execute(int iters, int async) {
|
||||
}
|
||||
BENCHMARK(BM_Execute)->Arg(0)->Arg(1);
|
||||
|
||||
void BM_Execute_Identity(int iters, int async) {
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::testing::SetLabel(async ? "ExecuteIdentityAsync"
|
||||
: "ExecuteIdentity");
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* identity = IdentityOp(ctx, m);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_Execute(identity, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
if (async) {
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(identity);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteContext(ctx);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
BENCHMARK(BM_Execute_Identity)->Arg(0)->Arg(1);
|
||||
|
||||
TEST(CAPI, Context) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -196,8 +235,10 @@ void TestRemoteExecute(bool async) {
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
@ -282,9 +323,11 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
@ -298,7 +341,7 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -324,33 +367,49 @@ void TestRemoteExecuteDeleteTensorAfterContext(bool async) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* h1_task1 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// Delete tensors after context is deleted.
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h0_task1);
|
||||
TFE_DeleteTensorHandle(h1_task1);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(false);
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) {
|
||||
TestRemoteExecuteDeleteTensorAfterContext(true);
|
||||
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
@ -397,8 +456,10 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
@ -433,8 +494,9 @@ void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
@ -476,8 +538,9 @@ void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
@ -610,8 +673,11 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
|
||||
TFE_TensorHandle* hcopy =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_ContextAsyncWait(ctx, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get()));
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteTensorHandle(hcopy);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
|
||||
@ -740,8 +806,10 @@ void TensorHandleSilentCopy(bool async) {
|
||||
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_ContextAsyncWait(ctx, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
@ -786,8 +854,10 @@ void TensorHandleSilentCopyLocal(bool async) {
|
||||
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_ContextAsyncWait(ctx, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
|
||||
@ -921,8 +991,10 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
}
|
||||
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_ContextAsyncWait(ctx, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
@ -1000,9 +1072,11 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
|
||||
retvals[0] = nullptr;
|
||||
TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status));
|
||||
TFE_ContextAsyncClearError(ctx);
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorClearError(executor);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
// Following works in async mode since TFE_ContextAsyncClearError was called.
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
@ -1220,147 +1294,6 @@ void ExecuteWithTracing(bool async) {
|
||||
TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); }
|
||||
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); }
|
||||
|
||||
TEST(CAPI, Function_ident_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
TF_OperationDescription* arg_descr =
|
||||
TF_NewOperation(function_graph, "Placeholder", "arg");
|
||||
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_OperationDescription* id_descr =
|
||||
TF_NewOperation(function_graph, "Identity", "id");
|
||||
TF_SetAttrType(id_descr, "T", TF_INT32);
|
||||
TF_AddInput(id_descr, {arg, 0});
|
||||
TF_Operation* id = TF_FinishOperation(id_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_Output input{arg, 0};
|
||||
TF_Output output{id, 0};
|
||||
TF_Function* fn =
|
||||
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
|
||||
&output, nullptr, nullptr, "test", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteGraph(function_graph);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_ContextAddFunction(ctx, fn, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteFunction(fn);
|
||||
|
||||
for (bool async : {false, true, false}) {
|
||||
TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
|
||||
status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
|
||||
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
|
||||
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, h, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
|
||||
std::vector<TFE_TensorHandle*> result;
|
||||
result.push_back(nullptr);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, result.data(), &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(num_retvals, 1);
|
||||
|
||||
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
TEST(CAPI, Function_ident_XLA_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
TF_OperationDescription* arg_descr =
|
||||
TF_NewOperation(function_graph, "Placeholder", "arg");
|
||||
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_OperationDescription* id_descr =
|
||||
TF_NewOperation(function_graph, "Identity", "id");
|
||||
TF_SetAttrType(id_descr, "T", TF_INT32);
|
||||
TF_AddInput(id_descr, {arg, 0});
|
||||
TF_Operation* id = TF_FinishOperation(id_descr, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_Output input{arg, 0};
|
||||
TF_Output output{id, 0};
|
||||
TF_Function* fn =
|
||||
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
|
||||
&output, nullptr, nullptr, "test", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteGraph(function_graph);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_ContextAddFunction(ctx, fn, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteFunction(fn);
|
||||
|
||||
for (bool async : {false, true, false}) {
|
||||
TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
|
||||
status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
|
||||
TF_Tensor* t =
|
||||
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
|
||||
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
|
||||
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(op, h, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
|
||||
// Now run it via XLA.
|
||||
TFE_OpSetXLACompilation(op, true);
|
||||
|
||||
std::vector<TFE_TensorHandle*> result;
|
||||
result.push_back(nullptr);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, result.data(), &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(num_retvals, 1);
|
||||
|
||||
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TF_DeleteTensor(r);
|
||||
TFE_DeleteTensorHandle(result[0]);
|
||||
}
|
||||
TFE_ContextRemoveFunction(ctx, "ident", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
string MatMulFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
@ -1474,7 +1407,10 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
if (async) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteTensorHandle(m);
|
||||
|
@ -85,6 +85,24 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
|
||||
constexpr int64_t dims[] = {100, 100};
|
||||
constexpr int num_elements = dims[0] * dims[1];
|
||||
float data[num_elements];
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data[i] = 1.0f;
|
||||
}
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
|
||||
int64_t dims[] = {3, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
|
||||
@ -128,6 +146,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, a, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
|
||||
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
@ -34,6 +33,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle();
|
||||
// Return a tensor handle containing a 2x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle();
|
||||
|
||||
// Return a tensor handle containing a 100x100 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100();
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
|
||||
@ -43,6 +45,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||
// Return a matmul op multiplying `a` by `b`.
|
||||
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
// Return an identity op.
|
||||
TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
|
||||
// Return a shape op fetching the shape of `a`.
|
||||
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
// Language-agnostic gradient tape. Does not perform backpropagation, just
|
||||
// maintains the data structures required to do so.
|
||||
|
||||
#include <stack>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -209,7 +210,9 @@ class ForwardAccumulator {
|
||||
// ForwardAccumulator.
|
||||
explicit ForwardAccumulator(
|
||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
|
||||
: vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {}
|
||||
: vspace_(vspace) {
|
||||
call_state_.emplace(nullptr, false);
|
||||
}
|
||||
|
||||
virtual ~ForwardAccumulator() {
|
||||
for (auto accumulated : accumulated_gradients_) {
|
||||
@ -262,6 +265,12 @@ class ForwardAccumulator {
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter);
|
||||
|
||||
// Returns true if `Accumulate` is active somewhere above on the stack and
|
||||
// there isn't an intervening PushState. This is useful for ordering
|
||||
// ForwardAccumulators, where more deeply nested accumulators should not see
|
||||
// computations from less deeply nested accumulators.
|
||||
bool BusyAccumulating() const { return call_state_.top().accumulating; }
|
||||
|
||||
// Fetches the current Jacobian-vector product associated with `tensor_id`, or
|
||||
// a nullptr if none is available.
|
||||
//
|
||||
@ -276,6 +285,15 @@ class ForwardAccumulator {
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||
|
||||
// Temporarily push or pop transient state for this accumulator.
|
||||
//
|
||||
// Allows an accumulator which is currently processing an operation to
|
||||
// temporarily reset its state. Without pushing and poping, accumulators
|
||||
// ignore operations executed as a direct result of their own jvp
|
||||
// computations.
|
||||
void PushState() { call_state_.emplace(nullptr, false); }
|
||||
void PopState() { call_state_.pop(); }
|
||||
|
||||
private:
|
||||
// Helper for Accumulate: uses a GradientTape to compute forward gradients
|
||||
// from a backward gradient function. Fills `out_grads` corresponding to
|
||||
@ -283,7 +301,7 @@ class ForwardAccumulator {
|
||||
//
|
||||
// Executes the backward function in order to trace its gradient, which will
|
||||
// waste computation if executing eagerly (when graph building the unneeded
|
||||
// computation is pruned). Temporarily sets `backward_tape_` so that
|
||||
// computation is pruned). Temporarily sets `backward_tape` so that
|
||||
// Accumulate will forward op executions to the tape while the backward
|
||||
// function is running; this effectively adds the backward tape to the active
|
||||
// set (but does not require complicated callbacks to the language bindings).
|
||||
@ -299,16 +317,26 @@ class ForwardAccumulator {
|
||||
// Not owned; provides operations on Tensors which are currently only
|
||||
// available in language bindings (e.g. Python).
|
||||
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
|
||||
// Set temporarily while in the Accumulate method; if backward_tape_ is not
|
||||
// nullptr then we forward op executions to it so Accumulate can compute a
|
||||
// backward pass on its backward function.
|
||||
//
|
||||
// Not owned by the ForwardAccumulator. The method which sets `backward_tape_`
|
||||
// keeps ownership.
|
||||
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape_;
|
||||
// While the Accumulate method is running (accumulating_ is True), any op
|
||||
// executions not forwarded to backward_tape_ should be ignored.
|
||||
bool accumulating_;
|
||||
|
||||
struct AccumulatorCallState {
|
||||
AccumulatorCallState(
|
||||
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
|
||||
bool accumulating)
|
||||
: backward_tape(backward_tape), accumulating(accumulating) {}
|
||||
// Set temporarily while in the Accumulate method; if backward_tape is not
|
||||
// nullptr then we forward op executions to it so Accumulate can compute a
|
||||
// backward pass on its backward function.
|
||||
//
|
||||
// Not owned by the ForwardAccumulator. The method which sets
|
||||
// `backward_tape` keeps ownership.
|
||||
GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
|
||||
// While the Accumulate method is running (accumulating is True), any op
|
||||
// executions not forwarded to backward_tape should be ignored.
|
||||
bool accumulating;
|
||||
};
|
||||
// A deque-backed stack, whose element references are not invalidated by
|
||||
// pushes and pops at the back.
|
||||
std::stack<AccumulatorCallState> call_state_;
|
||||
};
|
||||
|
||||
// Template instantiations here
|
||||
@ -841,12 +869,12 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
||||
gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
||||
if (backward_tape_ != nullptr) {
|
||||
// If we're forwarding Accumulate calls to backward_tape_'s RecordOperation,
|
||||
if (call_state_.top().backward_tape != nullptr) {
|
||||
// If we're forwarding Accumulate calls to backward_tape's RecordOperation,
|
||||
// we should also delegate ShouldRecord.
|
||||
return backward_tape_->ShouldRecord(tensor_ids, dtypes);
|
||||
return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
|
||||
}
|
||||
if (accumulating_) {
|
||||
if (call_state_.top().accumulating) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < tensor_ids.size(); ++i) {
|
||||
@ -878,9 +906,10 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
*/
|
||||
std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
|
||||
new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
|
||||
backward_tape_ = tape.get();
|
||||
AccumulatorCallState& call_state = call_state_.top();
|
||||
call_state.backward_tape = tape.get();
|
||||
auto pop_backward_tape =
|
||||
gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; });
|
||||
gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
|
||||
std::vector<Gradient*> forwardprop_aids;
|
||||
std::vector<int64> sources;
|
||||
std::unordered_set<int64> sources_set;
|
||||
@ -955,10 +984,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
const ForwardFunction<Gradient>* forward_function,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
|
||||
if (backward_tape_ != nullptr) {
|
||||
// If backward_tape_ is not null, then this call to Accumulate is the result
|
||||
if (call_state_.top().backward_tape != nullptr) {
|
||||
// If backward_tape is not null, then this call to Accumulate is the result
|
||||
// of a still-active call to Accumulate which is running operations. We
|
||||
// forward these operations to backward_tape_ so the outer Accumulate call
|
||||
// forward these operations to backward_tape so the outer Accumulate call
|
||||
// can do its work.
|
||||
//
|
||||
// Rather than re-entering and delegating Accumulate like this, we could
|
||||
@ -966,9 +995,9 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
// (so it can deactivate itself and activate its GradientTape). Currently
|
||||
// that is managed by the language binding and would require relatively
|
||||
// messy callbacks.
|
||||
backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id,
|
||||
input_dtypes, backward_function_getter,
|
||||
backward_function_deleter);
|
||||
call_state_.top().backward_tape->RecordOperation(
|
||||
op_type, output_tensors, input_tensor_id, input_dtypes,
|
||||
backward_function_getter, backward_function_deleter);
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
|
||||
@ -1006,9 +1035,8 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
|
||||
// Avoid infinite recursion. Whichever forward function we run, it'll end up
|
||||
// executing ops, and we don't want to watch those with this accumulator.
|
||||
accumulating_ = true;
|
||||
auto reset_accumulating =
|
||||
gtl::MakeCleanup([this] { this->accumulating_ = false; });
|
||||
call_state_.emplace(nullptr, true);
|
||||
auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
|
||||
|
||||
std::vector<Gradient*> forward_grads;
|
||||
if (forward_function == nullptr) {
|
||||
|
@ -45,6 +45,9 @@ CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
if (args.cancellation_manager != nullptr) {
|
||||
VLOG(1) << "WARNING: CRemoteRendezvous does not support cancellation.";
|
||||
}
|
||||
TF_ParsedKey key;
|
||||
key.src_device = parsed.src_device.data();
|
||||
key.src_device_len = parsed.src_device.size();
|
||||
|
@ -63,12 +63,26 @@ cat << EOF > tensorflow.pc
|
||||
prefix=${TF_PREFIX}
|
||||
exec_prefix=\${prefix}
|
||||
libdir=\${exec_prefix}/${LIBDIR}
|
||||
includedir=\${prefix}/include
|
||||
includedir=\${prefix}/include/tensorflow
|
||||
|
||||
Name: TensorFlow
|
||||
Version: ${TF_VERSION}
|
||||
Description: Library for computation using data flow graphs for scalable machine learning
|
||||
Requires:
|
||||
Libs: -L\${libdir} -ltensorflow
|
||||
Libs: -L\${libdir} -ltensorflow -ltensorflow_framework
|
||||
Cflags: -I\${includedir}
|
||||
EOF
|
||||
|
||||
cat << EOF > tensorflow_cc.pc
|
||||
prefix=${TF_PREFIX}
|
||||
exec_prefix=\${prefix}
|
||||
libdir=\${exec_prefix}/${LIBDIR}
|
||||
includedir=\${prefix}/include/tensorflow
|
||||
|
||||
Name: TensorFlow
|
||||
Version: ${TF_VERSION}
|
||||
Description: Library for computation using data flow graphs for scalable machine learning
|
||||
Requires:
|
||||
Libs: -L\${libdir} -ltensorflow_cc -ltensorflow_framework
|
||||
Cflags: -I\${includedir}
|
||||
EOF
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -189,8 +190,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
|
||||
void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor,
|
||||
TF_Status* status) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
if (i < 0 || i >= cc_ctx->num_inputs()) {
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
|
||||
if (i < 0 || i >= cc_ctx->num_outputs()) {
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range");
|
||||
return;
|
||||
}
|
||||
::tensorflow::Tensor cc_tensor;
|
||||
@ -240,3 +241,14 @@ TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
|
||||
int64_t TF_StepId(TF_OpKernelContext* ctx) {
|
||||
return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id();
|
||||
}
|
||||
|
||||
TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
||||
TF_DataType dtype, int64_t* dims, int num_dims,
|
||||
size_t len) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
||||
tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
|
||||
auto* allocator = cc_ctx->get_allocator(attr);
|
||||
void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
|
||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
||||
tensorflow::deallocate_buffer, allocator);
|
||||
}
|
||||
|
@ -180,6 +180,16 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Allocates Tensor for output at given index. Caller takes ownership of
|
||||
// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
|
||||
//
|
||||
// This function should be used to allocate outputs inside kernel
|
||||
// compute function.
|
||||
TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
|
||||
int index, TF_DataType dtype,
|
||||
int64_t* dims, int num_dims,
|
||||
size_t len);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -12,17 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
#include "tensorflow/c/kernels.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -309,4 +315,144 @@ TEST(TestKernel, TestHostMemory) {
|
||||
TF_DeleteKernelBuilder(builder);
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
|
||||
class DeviceKernelOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetupOp(const char* op_name, const char* kernel_name,
|
||||
void (*compute_func)(void*, TF_OpKernelContext*)) {
|
||||
TF_KernelBuilder* builder = TF_NewKernelBuilder(
|
||||
op_name, device_name_, nullptr, compute_func, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
|
||||
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
|
||||
#endif
|
||||
TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
const char* device_name_ = tensorflow::DEVICE_GPU;
|
||||
#else
|
||||
const char* device_name_ = tensorflow::DEVICE_CPU;
|
||||
#endif
|
||||
};
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
// Allocate output
|
||||
int64_t dim = 1;
|
||||
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
|
||||
TF_Tensor* output = TF_AllocateOutput(
|
||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||
/*num_dims=*/1, /*len=*/tensor_size_bytes);
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
||||
EXPECT_EQ(1, TF_NumDims(output));
|
||||
EXPECT_EQ(1, TF_Dim(output, 0));
|
||||
|
||||
// Set output to 3
|
||||
float* data = reinterpret_cast<float*>(TF_TensorData(output));
|
||||
float value = 3.0f;
|
||||
#if GOOGLE_CUDA
|
||||
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
||||
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
|
||||
tensor_size_bytes);
|
||||
#else
|
||||
*data = value;
|
||||
#endif
|
||||
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_SetOutput(ctx, 0, output, s);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||
|
||||
TF_DeleteStatus(s);
|
||||
TF_DeleteTensor(output);
|
||||
};
|
||||
|
||||
SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
Tensor* output = GetOutput(0);
|
||||
EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
|
||||
output->DebugString(100));
|
||||
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp0").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
|
||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
// Allocate empty output
|
||||
int64_t dim = 0;
|
||||
TF_Tensor* output = TF_AllocateOutput(
|
||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
|
||||
/*num_dims=*/1, /*len=*/0);
|
||||
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
||||
EXPECT_EQ(1, TF_NumDims(output));
|
||||
EXPECT_EQ(0, TF_Dim(output, 0));
|
||||
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_SetOutput(ctx, 0, output, s);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||
|
||||
TF_DeleteStatus(s);
|
||||
TF_DeleteTensor(output);
|
||||
};
|
||||
|
||||
SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
Tensor* output = GetOutput(0);
|
||||
EXPECT_EQ("Tensor<type: float shape: [0] values: >",
|
||||
output->DebugString(100));
|
||||
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp2x3").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
|
||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
// Allocate 2x3 output
|
||||
int64_t dim[2] = {2, 3};
|
||||
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
|
||||
TF_Tensor* output = TF_AllocateOutput(
|
||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
|
||||
/*num_dims=*/2, /*len=*/tensor_size_bytes);
|
||||
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
||||
EXPECT_EQ(2, TF_NumDims(output));
|
||||
EXPECT_EQ(2, TF_Dim(output, 0));
|
||||
EXPECT_EQ(3, TF_Dim(output, 1));
|
||||
|
||||
// Set output to [1 2 3 4 5 6]
|
||||
void* data = TF_TensorData(output);
|
||||
float value[6] = {1, 2, 3, 4, 5, 6};
|
||||
#if GOOGLE_CUDA
|
||||
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
|
||||
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
|
||||
tensor_size_bytes);
|
||||
#else
|
||||
memcpy(data, value, tensor_size_bytes);
|
||||
#endif
|
||||
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_SetOutput(ctx, 0, output, s);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||
|
||||
TF_DeleteStatus(s);
|
||||
TF_DeleteTensor(output);
|
||||
};
|
||||
|
||||
SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
Tensor* output = GetOutput(0);
|
||||
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
||||
output->DebugString(100));
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -31,6 +31,37 @@ using tensorflow::TensorBuffer;
|
||||
using tensorflow::errors::FailedPrecondition;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
|
||||
namespace tensorflow {
|
||||
void* allocate_tensor(const char* operation, size_t len, Allocator* allocator) {
|
||||
void* data = allocator->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
|
||||
if (LogMemory::IsEnabled() && data != nullptr) {
|
||||
LogMemory::RecordRawAllocation(
|
||||
operation, LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len, data,
|
||||
allocator);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
void* allocate_tensor(const char* operation, size_t len) {
|
||||
return allocate_tensor(operation, len, cpu_allocator());
|
||||
}
|
||||
|
||||
void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
Allocator* allocator = nullptr;
|
||||
if (arg == nullptr) {
|
||||
allocator = cpu_allocator();
|
||||
} else {
|
||||
allocator = reinterpret_cast<Allocator*>(arg);
|
||||
}
|
||||
if (LogMemory::IsEnabled() && data != nullptr) {
|
||||
LogMemory::RecordRawDeallocation(
|
||||
"TensorFlow C Api", LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
|
||||
allocator, false);
|
||||
}
|
||||
allocator->DeallocateRaw(data);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
class TF_ManagedBuffer : public TensorBuffer {
|
||||
public:
|
||||
@ -63,36 +94,15 @@ class TF_ManagedBuffer : public TensorBuffer {
|
||||
bool OwnsMemory() const override { return false; }
|
||||
};
|
||||
|
||||
void* allocate_tensor(const char* operation, size_t len) {
|
||||
void* data =
|
||||
tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
|
||||
if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
|
||||
tensorflow::LogMemory::RecordRawAllocation(
|
||||
operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
|
||||
len, data, tensorflow::cpu_allocator());
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
|
||||
tensorflow::LogMemory::RecordRawDeallocation(
|
||||
"TensorFlow C Api",
|
||||
tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
|
||||
tensorflow::cpu_allocator(), false);
|
||||
}
|
||||
tensorflow::cpu_allocator()->DeallocateRaw(data);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
|
||||
|
||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||
int num_dims, size_t len) {
|
||||
void* data = allocate_tensor("TF_AllocateTensor", len);
|
||||
return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
|
||||
nullptr);
|
||||
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
|
||||
tensorflow::cpu_allocator());
|
||||
return TF_NewTensor(dtype, dims, num_dims, data, len,
|
||||
tensorflow::deallocate_buffer,
|
||||
tensorflow::cpu_allocator());
|
||||
}
|
||||
|
||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
@ -117,8 +127,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
//
|
||||
// Other types have the same representation, so copy only if it is safe to
|
||||
// do so.
|
||||
buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
|
||||
deallocate_buffer, nullptr);
|
||||
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
|
||||
len, tensorflow::deallocate_buffer, nullptr);
|
||||
std::memcpy(buf->data(), data, len);
|
||||
// Free the original buffer.
|
||||
deallocator(data, len, deallocator_arg);
|
||||
@ -126,9 +136,12 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
|
||||
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
}
|
||||
|
||||
TF_Tensor* ret = new TF_Tensor{dtype, tensorflow::TensorShape(dimvec), buf};
|
||||
TF_Tensor* ret =
|
||||
new TF_Tensor{Tensor(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf)};
|
||||
buf->Unref();
|
||||
size_t elem_size = TF_DataTypeSize(dtype);
|
||||
if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
|
||||
if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) {
|
||||
delete ret;
|
||||
return nullptr;
|
||||
}
|
||||
@ -139,7 +152,7 @@ TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
|
||||
// It is safe to move the Tensor if and only if we own the unique reference to
|
||||
// it. In that case, we might as well not delete and reallocate, but a future
|
||||
// implementation might need to do so.
|
||||
TensorBuffer* buf = tensor->buffer;
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor);
|
||||
if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
|
||||
buf->OwnsMemory()) {
|
||||
return tensor;
|
||||
@ -149,13 +162,23 @@ TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
|
||||
|
||||
void TF_DeleteTensor(TF_Tensor* t) { delete t; }
|
||||
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
|
||||
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
||||
return static_cast<int64_t>(t->shape.dim_size(dim_index));
|
||||
TF_DataType TF_TensorType(const TF_Tensor* t) {
|
||||
return static_cast<TF_DataType>(t->tensor.dtype());
|
||||
}
|
||||
|
||||
int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); }
|
||||
|
||||
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
|
||||
return static_cast<int64_t>(t->tensor.dim_size(dim_index));
|
||||
}
|
||||
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->size();
|
||||
}
|
||||
|
||||
void* TF_TensorData(const TF_Tensor* t) {
|
||||
return tensorflow::TensorCApi::Buffer(t->tensor)->data();
|
||||
}
|
||||
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
|
||||
void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
|
||||
|
||||
int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
||||
int64_t result = 1;
|
||||
@ -166,63 +189,17 @@ int64_t TF_TensorElementCount(const TF_Tensor* t) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the number of elements that would be present in a tensor with the
|
||||
// given shape.
|
||||
static int64_t ShapeNumElements(const int64_t* dims, int num_dims) {
|
||||
int64_t result = 1;
|
||||
for (int dim = 0; dim < num_dims; ++dim) {
|
||||
result *= dims[dim];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) {
|
||||
if (buf != nullptr) {
|
||||
buf->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
static void RefIfNonNull(::tensorflow::TensorBuffer* buf) {
|
||||
if (buf != nullptr) {
|
||||
buf->Ref();
|
||||
}
|
||||
}
|
||||
|
||||
void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
|
||||
TF_Tensor* to, const int64_t* new_dims,
|
||||
int num_new_dims, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
size_t in_size = TF_DataTypeSize(TF_TensorType(from));
|
||||
if (in_size == 0) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"input tensor has a zero-sized data type");
|
||||
return;
|
||||
}
|
||||
size_t out_size = TF_DataTypeSize(type);
|
||||
if (out_size == 0) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"output tensor has a zero-sized data type");
|
||||
return;
|
||||
}
|
||||
|
||||
if (ShapeNumElements(new_dims, num_new_dims) * out_size !=
|
||||
TF_TensorElementCount(from) * in_size) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"input tensor is not compatible with output shape");
|
||||
return;
|
||||
}
|
||||
|
||||
tensorflow::TensorShapeProto p;
|
||||
tensorflow::TensorShape s;
|
||||
for (int i = 0; i < num_new_dims; ++i) {
|
||||
p.add_dim()->set_size(new_dims[i]);
|
||||
}
|
||||
to->shape = tensorflow::TensorShape(p);
|
||||
to->dtype = type;
|
||||
if (to->buffer != from->buffer) {
|
||||
UnrefIfNonNull(to->buffer);
|
||||
to->buffer = from->buffer;
|
||||
RefIfNonNull(to->buffer);
|
||||
s.AddDim(new_dims[i]);
|
||||
}
|
||||
Status cc_status(to->tensor.BitcastFrom(
|
||||
from->tensor, static_cast<tensorflow::DataType>(type), s));
|
||||
Set_TF_Status_from_Status(status, cc_status);
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@ -332,17 +309,19 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
return t;
|
||||
}
|
||||
if (src.dtype() != tensorflow::DT_STRING) {
|
||||
TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src);
|
||||
buf->Ref();
|
||||
return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
|
||||
buf};
|
||||
auto* result = new TF_Tensor();
|
||||
if (!result->tensor.CopyFrom(src, src.shape())) {
|
||||
delete result;
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
|
||||
// encoded sequence of strings.
|
||||
|
||||
// Compute bytes needed for encoding.
|
||||
size_t size = 0;
|
||||
const auto& srcarray = src.flat<string>();
|
||||
const auto& srcarray = src.flat<tstring>();
|
||||
for (int i = 0; i < srcarray.size(); ++i) {
|
||||
const string& s = srcarray(i);
|
||||
// uint64 starting_offset, TF_StringEncode-d string.
|
||||
@ -393,14 +372,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
}
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
if (src->dtype == TF_RESOURCE) {
|
||||
if (src->shape.dims() != 0) {
|
||||
if (src->tensor.dtype() == DT_RESOURCE) {
|
||||
if (src->tensor.dims() != 0) {
|
||||
return InvalidArgument(
|
||||
"Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
|
||||
"shape ",
|
||||
src->shape.DebugString());
|
||||
src->tensor.shape().DebugString());
|
||||
}
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, src->shape);
|
||||
*dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape());
|
||||
if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
|
||||
string(static_cast<const char*>(TF_TensorData(src)),
|
||||
TF_TensorByteSize(src)))) {
|
||||
@ -409,14 +388,13 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (src->dtype != TF_STRING) {
|
||||
*dst =
|
||||
tensorflow::TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
|
||||
if (src->tensor.dtype() != DT_STRING) {
|
||||
*dst = src->tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
// TF_STRING tensors require copying since Tensor class expects a sequence of
|
||||
// string objects.
|
||||
const tensorflow::int64 num_elements = src->shape.num_elements();
|
||||
const tensorflow::int64 num_elements = src->tensor.NumElements();
|
||||
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
|
||||
const size_t src_size = TF_TensorByteSize(src);
|
||||
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
|
||||
@ -427,8 +405,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||
const char* limit = input + src_size;
|
||||
|
||||
*dst = Tensor(static_cast<tensorflow::DataType>(src->dtype), src->shape);
|
||||
auto dstarray = dst->flat<string>();
|
||||
*dst = Tensor(src->tensor.dtype(), src->tensor.shape());
|
||||
auto dstarray = dst->flat<tstring>();
|
||||
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
|
||||
tensorflow::uint64 offset =
|
||||
reinterpret_cast<const tensorflow::uint64*>(input)[i];
|
||||
@ -447,3 +425,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
bool TF_TensorIsAligned(const TF_Tensor* tensor) {
|
||||
return tensor->tensor.IsAligned();
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_TF_TENSOR_H_
|
||||
#define TENSORFLOW_C_TF_TENSOR_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
@ -175,6 +176,9 @@ TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len,
|
||||
// TF_STRING tensor.
|
||||
TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
|
||||
|
||||
// Returns bool iff this tensor is aligned.
|
||||
TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -23,13 +23,12 @@ limitations under the License.
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
// not be depended on.
|
||||
|
||||
struct TF_Tensor {
|
||||
~TF_Tensor();
|
||||
|
||||
TF_DataType dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
tensorflow::TensorBuffer* buffer;
|
||||
};
|
||||
// This struct forms part of the C API's public interface. It must strictly be
|
||||
// passed to or returned from C functions *by pointer*. Otherwise, changes to
|
||||
// its internal structure will break the C API's binary interface.
|
||||
typedef struct TF_Tensor {
|
||||
::tensorflow::Tensor tensor;
|
||||
} TF_Tensor;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -42,5 +41,13 @@ class TensorCApi {
|
||||
}
|
||||
};
|
||||
|
||||
// Allocates tensor data buffer using specified allocator.
|
||||
// `operation` is a name for this operation.
|
||||
void* allocate_tensor(const char* operation, size_t len, Allocator* allocator);
|
||||
|
||||
// Deallocates tensor data buffer.
|
||||
// Defaults to deallocating using CPU allocator. You can pass pointer to
|
||||
// a different Allocator as `arg`.
|
||||
void deallocate_buffer(void* data, size_t len, void* arg);
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
@ -649,7 +649,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
@ -667,7 +666,6 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb_text.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
@ -193,12 +193,12 @@ string PrintTensor(const TensorProto& tensor_proto) {
|
||||
string ret;
|
||||
for (int64 i = 0; i < num_elts; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, " ");
|
||||
strings::StrAppend(&ret, absl::CEscape(t.flat<string>()(i)));
|
||||
strings::StrAppend(&ret, absl::CEscape(t.flat<tstring>()(i)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype());
|
||||
LOG(FATAL) << "Not handling type " << DataType_Name(t.dtype());
|
||||
return string();
|
||||
}
|
||||
}
|
||||
@ -223,7 +223,7 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
|
||||
case AttrValue::kB:
|
||||
return attr_value.b() ? "true" : "false";
|
||||
case AttrValue::kType:
|
||||
return EnumName_DataType(attr_value.type());
|
||||
return DataType_Name(attr_value.type());
|
||||
case AttrValue::kShape:
|
||||
return PrintTensorShape(attr_value.shape());
|
||||
case AttrValue::kTensor:
|
||||
@ -254,8 +254,7 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
|
||||
} else if (attr_value.list().type_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().type_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret,
|
||||
EnumName_DataType(attr_value.list().type(i)));
|
||||
strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i)));
|
||||
}
|
||||
} else if (attr_value.list().shape_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
|
||||
|
@ -200,10 +200,10 @@ TEST(CCOpTest, TemplatedConst) {
|
||||
test::ExpectTensorEqual<float>(
|
||||
out, test::AsTensor<float>({3.f, 2.f, -1.f, 0.f}, {2, 2}));
|
||||
|
||||
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
auto c2 = ops::Const<tstring>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
test::GetTensor(root, c2, &out);
|
||||
test::ExpectTensorEqual<string>(
|
||||
out, test::AsTensor<string>({"this", "is", "a", "constant"}, {4, 1}));
|
||||
test::ExpectTensorEqual<tstring>(
|
||||
out, test::AsTensor<tstring>({"this", "is", "a", "constant"}, {4, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, EmptyConst) {
|
||||
|
@ -97,7 +97,7 @@ Input::Initializer::Initializer(
|
||||
Tensor elem = e.tensor;
|
||||
if (first.tensor.dtype() == DT_STRING) {
|
||||
for (int i = 0; i < elem.NumElements(); ++i) {
|
||||
t.flat<string>()(offset + i) = elem.flat<string>()(i);
|
||||
t.flat<tstring>()(offset + i) = elem.flat<tstring>()(i);
|
||||
}
|
||||
offset += elem.NumElements();
|
||||
} else {
|
||||
|
@ -111,7 +111,7 @@ class Input {
|
||||
Initializer(const T& v) { // NOLINT(runtime/explicit)
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
|
||||
t.flat<T>()(0) = RealT(v);
|
||||
t.flat<RealT>()(0) = RealT(v);
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
@ -125,7 +125,7 @@ class Input {
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), shape);
|
||||
for (int64 i = 0; i < t.NumElements(); ++i) {
|
||||
t.flat<T>()(i) = RealT(v);
|
||||
t.flat<RealT>()(i) = RealT(v);
|
||||
}
|
||||
tensor = t;
|
||||
}
|
||||
@ -170,7 +170,7 @@ class Input {
|
||||
// START_SKIP_DOXYGEN
|
||||
template <typename T, bool = std::is_convertible<T, string>::value>
|
||||
struct RealType {
|
||||
typedef string type;
|
||||
typedef tstring type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -272,7 +272,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
||||
std::unordered_set<string> current_constraints(colocation_constraints_);
|
||||
const AttrSlice attrs = colocate_with_op.node()->attrs();
|
||||
std::vector<string> node_constraints;
|
||||
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
|
||||
if (TryGetNodeAttr(attrs, kColocationAttrName, &node_constraints)) {
|
||||
for (const string& entry : node_constraints) {
|
||||
StringPiece s(entry);
|
||||
if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) {
|
||||
@ -299,7 +299,7 @@ const std::vector<Operation>& Scope::control_deps() const {
|
||||
return impl()->control_deps_;
|
||||
}
|
||||
|
||||
void Scope::UpdateStatus(const Status s) const {
|
||||
void Scope::UpdateStatus(const Status& s) const {
|
||||
impl()->status_->Update(s);
|
||||
if (impl()->exit_on_error_ && !ok()) {
|
||||
LOG(FATAL) << *impl()->status_;
|
||||
@ -318,7 +318,7 @@ Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const {
|
||||
if (ok()) {
|
||||
GraphDef graph_def;
|
||||
graph()->ToGraphDef(&graph_def);
|
||||
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
|
||||
UpdateStatus(ConvertGraphDefToGraph(opts, std::move(graph_def), g));
|
||||
}
|
||||
return *impl()->status_;
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ class Scope {
|
||||
/// Note: The status object is shared between all children of this scope.
|
||||
/// If the resulting status is not Status::OK() and exit_on_error_ is set on
|
||||
/// this scope, this function exits by calling LOG(FATAL).
|
||||
void UpdateStatus(const Status s) const;
|
||||
void UpdateStatus(const Status& s) const;
|
||||
|
||||
// START_SKIP_DOXYGEN
|
||||
|
||||
|
@ -97,7 +97,7 @@ TEST(ConstOpTest, WithExplicitShape) {
|
||||
auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3});
|
||||
TF_CHECK_OK(root.status());
|
||||
EXPECT_EQ(d.op().output_type(0), DT_STRING);
|
||||
ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
|
||||
ExpectNodeEqual<tstring>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, FromProto) {
|
||||
@ -144,7 +144,7 @@ TEST(ConstOpTest, TemplatedConst) {
|
||||
auto c1 = ops::Const<int>(root, {1, 2});
|
||||
ExpectTypeAndShape(c1.node(), DT_INT32, {2});
|
||||
|
||||
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
auto c2 = ops::Const<tstring>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1});
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ tf_cuda_cc_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.cc"],
|
||||
tags = [
|
||||
"no_rocm", # stream level tracing not supported on ROCm
|
||||
"nogpu", # b/77649654
|
||||
],
|
||||
deps = [
|
||||
|
@ -10,7 +10,7 @@ load(
|
||||
"tf_cc_test",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"//tensorflow/core/platform:default/build_config_root.bzl",
|
||||
"if_static",
|
||||
"if_static_and_not_mobile",
|
||||
)
|
||||
|
@ -75,7 +75,7 @@ Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
||||
|
||||
Tensor CreateStringTensor(const string& value) {
|
||||
Tensor tensor(DT_STRING, TensorShape({}));
|
||||
tensor.scalar<string>()() = value;
|
||||
tensor.scalar<tstring>()() = value;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
@ -219,7 +219,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
|
||||
// Add variables to the graph.
|
||||
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
|
||||
variables_path_tensor.scalar<string>()() = variables_path;
|
||||
variables_path_tensor.scalar<tstring>()() = variables_path;
|
||||
|
||||
std::vector<std::pair<string, Tensor>> inputs = {
|
||||
{string(variable_filename_const_op_name), variables_path_tensor}};
|
||||
|
@ -63,8 +63,8 @@ class LoaderTest : public ::testing::Test {
|
||||
bundle.session->Run({}, {"filename_tensor:0"}, {}, &path_outputs));
|
||||
ASSERT_EQ(1, path_outputs.size());
|
||||
|
||||
test::ExpectTensorEqual<string>(
|
||||
test::AsTensor<string>({"foo.txt"}, TensorShape({})), path_outputs[0]);
|
||||
test::ExpectTensorEqual<tstring>(
|
||||
test::AsTensor<tstring>({"foo.txt"}, TensorShape({})), path_outputs[0]);
|
||||
}
|
||||
|
||||
void CheckSavedModelBundle(const string& export_dir,
|
||||
@ -78,14 +78,14 @@ class LoaderTest : public ::testing::Test {
|
||||
const string output_name =
|
||||
signature_def.outputs().at(kRegressOutputs).name();
|
||||
|
||||
std::vector<string> serialized_examples;
|
||||
std::vector<tstring> serialized_examples;
|
||||
for (float x : {0, 1, 2, 3}) {
|
||||
serialized_examples.push_back(MakeSerializedExample(x));
|
||||
}
|
||||
|
||||
// Validate the half plus two behavior.
|
||||
Tensor input =
|
||||
test::AsTensor<string>(serialized_examples, TensorShape({4}));
|
||||
test::AsTensor<tstring>(serialized_examples, TensorShape({4}));
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(bundle.session->Run({{input_name, input}}, {output_name}, {},
|
||||
&outputs));
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Description:
|
||||
# CLIF wrappers for TensorFlow SavedModels.
|
||||
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc")
|
||||
load("//tensorflow/core/platform:default/build_config.bzl", "tf_py_clif_cc")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
|
@ -48,12 +48,12 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
|
||||
export_dir);
|
||||
}
|
||||
|
||||
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||
const std::unordered_set<string>& tags,
|
||||
Status FindMetaGraphDef(const std::unordered_set<string>& tags,
|
||||
SavedModel* saved_model_proto,
|
||||
MetaGraphDef* meta_graph_def) {
|
||||
LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
|
||||
<< " }";
|
||||
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
|
||||
for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
|
||||
// Get tags from the graph_def.
|
||||
std::unordered_set<string> graph_tags;
|
||||
for (const string& tag : graph_def.meta_info_def().tags()) {
|
||||
@ -61,7 +61,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||
}
|
||||
// Match with the set of tags provided.
|
||||
if (graph_tags == tags) {
|
||||
*meta_graph_def = graph_def;
|
||||
*meta_graph_def = std::move(graph_def);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
@ -81,7 +81,8 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||
MetaGraphDef* const meta_graph_def) {
|
||||
SavedModel saved_model_proto;
|
||||
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
|
||||
TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -42,6 +42,10 @@ void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
|
||||
tensor_names->insert(coo_sparse.values_tensor_name());
|
||||
tensor_names->insert(coo_sparse.indices_tensor_name());
|
||||
tensor_names->insert(coo_sparse.dense_shape_tensor_name());
|
||||
} else if (tensor_info.has_composite_tensor()) {
|
||||
for (const auto& component : tensor_info.composite_tensor().components()) {
|
||||
tensor_names->insert(component.name());
|
||||
}
|
||||
} else {
|
||||
tensor_names->insert(tensor_info.name());
|
||||
}
|
||||
|
@ -425,5 +425,63 @@ TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
|
||||
TestFreezeGraphWithAndWithoutDependentVariables(true);
|
||||
}
|
||||
|
||||
TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
|
||||
// Test that inputs and outputs get correctly populated for a
|
||||
// SignatureDef containing composite tensor inputs and outputs.
|
||||
SavedModelBundle saved_model_bundle;
|
||||
SignatureDef signature_def;
|
||||
|
||||
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
|
||||
in.mutable_composite_tensor()->add_components()->set_name("input1:0");
|
||||
in.mutable_composite_tensor()->add_components()->set_name("input2:0");
|
||||
|
||||
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
|
||||
out.mutable_composite_tensor()->add_components()->set_name("output2:0");
|
||||
out.mutable_composite_tensor()->add_components()->set_name("output1:0");
|
||||
|
||||
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
|
||||
&saved_model_bundle);
|
||||
GraphDef frozen_graph_def;
|
||||
std::unordered_set<string> inputs;
|
||||
std::unordered_set<string> outputs;
|
||||
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
|
||||
&outputs));
|
||||
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0"};
|
||||
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0"};
|
||||
EXPECT_EQ(expected_inputs, inputs);
|
||||
EXPECT_EQ(expected_outputs, outputs);
|
||||
}
|
||||
|
||||
TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
|
||||
// Test that inputs and outputs get correctly populated for a
|
||||
// SignatureDef containing composite tensor inputs and outputs.
|
||||
SavedModelBundle saved_model_bundle;
|
||||
SignatureDef signature_def;
|
||||
|
||||
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
|
||||
in.mutable_coo_sparse()->set_values_tensor_name("input1:0");
|
||||
in.mutable_coo_sparse()->set_indices_tensor_name("input2:0");
|
||||
in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0");
|
||||
|
||||
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
|
||||
out.mutable_coo_sparse()->set_values_tensor_name("output1:0");
|
||||
out.mutable_coo_sparse()->set_indices_tensor_name("output2:0");
|
||||
out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0");
|
||||
|
||||
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
|
||||
&saved_model_bundle);
|
||||
GraphDef frozen_graph_def;
|
||||
std::unordered_set<string> inputs;
|
||||
std::unordered_set<string> outputs;
|
||||
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
|
||||
&outputs));
|
||||
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0",
|
||||
"input3:0"};
|
||||
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0",
|
||||
"output3:0"};
|
||||
EXPECT_EQ(expected_inputs, inputs);
|
||||
EXPECT_EQ(expected_outputs, outputs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,7 +1,7 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -144,8 +144,57 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
XLA_DEVICE_DEPS = [
|
||||
":common",
|
||||
":xla_launch_util",
|
||||
":xla_tensor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:control_flow_ops_op_lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:no_op_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:optional_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
]
|
||||
|
||||
cc_library(
|
||||
name = "xla_device",
|
||||
name = "xla_device_no_jit_rewrite_registration",
|
||||
srcs = [
|
||||
"xla_compile_on_demand_op.cc",
|
||||
"xla_device.cc",
|
||||
@ -158,56 +207,22 @@ cc_library(
|
||||
"xla_device_context.h",
|
||||
"xla_device_ops.h",
|
||||
],
|
||||
deps = XLA_DEVICE_DEPS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_device",
|
||||
hdrs = [
|
||||
"xla_compile_on_demand_op.h",
|
||||
"xla_device.h",
|
||||
"xla_device_context.h",
|
||||
"xla_device_ops.h",
|
||||
],
|
||||
# Public visibility is needed for external TF/XLA backends.
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":common",
|
||||
deps = XLA_DEVICE_DEPS + [
|
||||
":jit_compilation_passes",
|
||||
":xla_launch_util",
|
||||
":xla_tensor",
|
||||
"//tensorflow/compiler/jit/ops:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:control_flow_ops_op_lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:no_op_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
"//tensorflow/core/kernels:fifo_queue",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:shape_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/kernels/data:generator_dataset_op",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
"//tensorflow/core/kernels/data:optional_ops",
|
||||
"//tensorflow/core/kernels/data:prefetch_dataset_op",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
":xla_device_no_jit_rewrite_registration",
|
||||
],
|
||||
)
|
||||
|
||||
@ -281,6 +296,7 @@ cc_library(
|
||||
hdrs = ["xla_compilation_cache.h"],
|
||||
deps = [
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -292,6 +308,8 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -324,17 +342,21 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Linked by tensorflow core, without registration of jit compilation passes
|
||||
# which is not necessary to create and run a XlaLocalLaunchBase kernel.
|
||||
# Linking jit compilation passes could cause programs stuck right now (b/140069592).
|
||||
cc_library(
|
||||
name = "xla_kernel_creator",
|
||||
name = "xla_kernel_creator_util",
|
||||
srcs = [
|
||||
"xla_kernel_creator.cc",
|
||||
"xla_kernel_creator.h",
|
||||
"xla_kernel_creator_util.cc",
|
||||
],
|
||||
hdrs = ["xla_kernel_creator_util.h"],
|
||||
visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"],
|
||||
deps = [
|
||||
":common",
|
||||
":compilability_check_util",
|
||||
":compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -347,6 +369,23 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_kernel_creator",
|
||||
srcs = [
|
||||
"xla_kernel_creator.cc",
|
||||
"xla_kernel_creator.h",
|
||||
],
|
||||
deps = [
|
||||
":jit_compilation_passes",
|
||||
":xla_kernel_creator_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_kernel_creator_test",
|
||||
srcs = [
|
||||
@ -498,6 +537,7 @@ cc_library(
|
||||
srcs = [
|
||||
"build_xla_ops_pass.cc",
|
||||
"clone_constants_for_better_clustering.cc",
|
||||
"cluster_scoping_pass.cc",
|
||||
"deadness_analysis.cc",
|
||||
"deadness_analysis_internal.h",
|
||||
"encapsulate_subgraphs_pass.cc",
|
||||
@ -513,6 +553,7 @@ cc_library(
|
||||
hdrs = [
|
||||
"build_xla_ops_pass.h",
|
||||
"clone_constants_for_better_clustering.h",
|
||||
"cluster_scoping_pass.h",
|
||||
"deadness_analysis.h",
|
||||
"encapsulate_subgraphs_pass.h",
|
||||
"encapsulate_xla_computations_pass.h",
|
||||
@ -677,6 +718,7 @@ tf_cc_test(
|
||||
srcs = [
|
||||
"build_xla_ops_pass_test.cc",
|
||||
"clone_constants_for_better_clustering_test.cc",
|
||||
"cluster_scoping_pass_test.cc",
|
||||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"encapsulate_xla_computations_pass_test.cc",
|
||||
"extract_outside_compilation_pass_test.cc",
|
||||
@ -800,6 +842,8 @@ cc_library(
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
":union_find",
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
@ -837,6 +881,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -901,6 +946,7 @@ cc_library(
|
||||
srcs = ["xla_activity_logging_listener.cc"],
|
||||
deps = [
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/core:logger",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
|
@ -48,6 +48,19 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
struct DebuggingOpts {
|
||||
// If true, insert Print nodes to print every output from an XLA cluster.
|
||||
bool print_outputs;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed input to
|
||||
// an XLA cluster.
|
||||
bool check_input_numerics;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed output
|
||||
// from an XLA cluster.
|
||||
bool check_output_numerics;
|
||||
};
|
||||
|
||||
void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
|
||||
std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
|
||||
old_node->out_edges().end());
|
||||
@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) {
|
||||
// Replaces each outgoing edge from `old_node` with a merge node that merges in
|
||||
// the corresponding output from `new_node`.
|
||||
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
|
||||
bool insert_print_nodes) {
|
||||
absl::string_view cluster_name,
|
||||
const DebuggingOpts& debugging_opts) {
|
||||
if (!s.status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
|
||||
int oidx = e->src_output();
|
||||
Output merged_output = merged_outputs[oidx];
|
||||
if (merged_output.node() == nullptr) {
|
||||
ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
|
||||
{Output(old_node, oidx), Output(new_node, oidx)});
|
||||
if (insert_print_nodes) {
|
||||
Output new_output(new_node, oidx);
|
||||
if (debugging_opts.print_outputs) {
|
||||
string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
|
||||
ops::Print print_op(s.WithOpName("print_", oidx)
|
||||
.WithDevice(cpu_device)
|
||||
.WithAssignedDevice(cpu_device),
|
||||
merge_op.output, {merge_op.output},
|
||||
new_output, {new_output},
|
||||
ops::Print::Attrs{}
|
||||
.Message(absl::StrCat("output ", oidx, " from ",
|
||||
old_node->name(), " is "))
|
||||
.FirstN(1000)
|
||||
.Summarize(-1));
|
||||
merged_output = merged_outputs[oidx] = print_op;
|
||||
} else {
|
||||
merged_output = merged_outputs[oidx] = merge_op.output;
|
||||
new_output = print_op;
|
||||
}
|
||||
|
||||
if (debugging_opts.check_output_numerics &&
|
||||
DataTypeIsFloating(new_output.type())) {
|
||||
ops::CheckNumerics check_numerics_op(
|
||||
s.WithOpName("check_output_", oidx)
|
||||
.WithDevice(new_node->requested_device())
|
||||
.WithAssignedDevice(new_node->assigned_device_name()),
|
||||
new_output,
|
||||
absl::StrCat("CheckNumerics failed for output ", oidx, "(",
|
||||
new_output.name(), ") from cluster ", cluster_name));
|
||||
new_output = check_numerics_op;
|
||||
}
|
||||
|
||||
ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
|
||||
{Output(old_node, oidx), new_output});
|
||||
merged_output = merged_outputs[oidx] = merge_op.output;
|
||||
}
|
||||
|
||||
Node* dst = e->dst();
|
||||
@ -324,11 +351,34 @@ xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<Output> GetXlaRunArgs(const Scope& s,
|
||||
const XlaClusterInfo& cluster_info,
|
||||
const DebuggingOpts& debugging_opts) {
|
||||
std::vector<Output> xla_run_args;
|
||||
xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
|
||||
cluster_info.resource_inputs.size());
|
||||
int input_idx = 0;
|
||||
for (const Output& o : cluster_info.non_constant_inputs) {
|
||||
if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
|
||||
ops::CheckNumerics check_numerics_op(
|
||||
s.WithOpName("check_input_", input_idx), o,
|
||||
absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
|
||||
o.name(), ") into ", cluster_info.function.name()));
|
||||
xla_run_args.push_back(check_numerics_op);
|
||||
} else {
|
||||
xla_run_args.push_back(o);
|
||||
}
|
||||
input_idx++;
|
||||
}
|
||||
absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
|
||||
return xla_run_args;
|
||||
}
|
||||
|
||||
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
jit::DeviceInfoCache* device_info_cache,
|
||||
const GraphOptimizationPassOptions& options,
|
||||
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
||||
bool insert_print_nodes, Graph* g, Node* n) {
|
||||
const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
|
||||
XlaClusterInfo cluster_info;
|
||||
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
||||
|
||||
@ -361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
TF_RETURN_IF_ERROR(
|
||||
CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
|
||||
|
||||
std::vector<Output> xla_run_args =
|
||||
GetXlaRunArgs(root, cluster_info, debugging_opts);
|
||||
|
||||
if (requires_compilation) {
|
||||
// "Strict" compilation: every _XlaCompile invocation must compile the
|
||||
// cluster.
|
||||
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
|
||||
absl::c_copy(cluster_info.resource_inputs,
|
||||
std::back_inserter(xla_run_args));
|
||||
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
|
||||
xla_compile.key, n->output_types());
|
||||
|
||||
@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
Output predicated_compilation_key = s.output_true;
|
||||
Output inverse_predicated_compilation_key = s.output_false;
|
||||
|
||||
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
|
||||
absl::c_copy(cluster_info.resource_inputs,
|
||||
std::back_inserter(xla_run_args));
|
||||
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
|
||||
predicated_compilation_key, n->output_types());
|
||||
|
||||
@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
|
||||
MergeOutgoingDataEdges(root, /*old_node=*/n,
|
||||
/*new_node=*/xla_run.operation.node(),
|
||||
insert_print_nodes);
|
||||
cluster_info.function.name(), debugging_opts);
|
||||
|
||||
TF_RETURN_IF_ERROR(root.status());
|
||||
|
||||
@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
enable_lazy_compilation_
|
||||
? *enable_lazy_compilation_
|
||||
: GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
|
||||
bool insert_print_nodes =
|
||||
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
||||
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
|
||||
|
||||
DebuggingOpts debugging_opts;
|
||||
debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
|
||||
debugging_opts.check_input_numerics =
|
||||
flags.tf_xla_check_cluster_input_numerics;
|
||||
debugging_opts.check_output_numerics =
|
||||
flags.tf_xla_check_cluster_output_numerics;
|
||||
|
||||
VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
|
||||
VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
|
||||
VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;
|
||||
|
||||
for (Node* n : xla_compiled_kernels) {
|
||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
&device_info_cache, options, *options.flib_def,
|
||||
lazy_compilation_enabled, insert_print_nodes, graph, n));
|
||||
lazy_compilation_enabled, debugging_opts, graph, n));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
|
163
tensorflow/compiler/jit/cluster_scoping_pass.cc
Normal file
163
tensorflow/compiler/jit/cluster_scoping_pass.cc
Normal file
@ -0,0 +1,163 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class ClusterScopingPassImpl {
|
||||
public:
|
||||
ClusterScopingPassImpl(Graph* graph,
|
||||
OptimizerOptions::GlobalJitLevel global_jit_level)
|
||||
: graph_(graph),
|
||||
global_jit_level_(global_jit_level),
|
||||
unique_scope_id_(0) {}
|
||||
|
||||
Status Run();
|
||||
|
||||
private:
|
||||
Status ScopingForPipelineStages();
|
||||
|
||||
size_t GetUniqueScopeId() { return unique_scope_id_++; }
|
||||
|
||||
void AddScopeToAllTransitivePredecessors(Node* start);
|
||||
|
||||
void AddScopeToAllTransitiveSuccessors(Node* start);
|
||||
|
||||
private:
|
||||
Graph* graph_;
|
||||
OptimizerOptions::GlobalJitLevel global_jit_level_;
|
||||
size_t unique_scope_id_;
|
||||
};
|
||||
|
||||
absl::optional<string> GetXlaInternalScope(Node* node) {
|
||||
string scope;
|
||||
if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
|
||||
return scope;
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
void SetXlaInternalScope(Node* node, StringPiece scope) {
|
||||
node->AddAttr(kXlaInternalScopeAttr, scope);
|
||||
}
|
||||
|
||||
// NB! We append a new scope as suffix to the _XlaInternalScope attribute
|
||||
// instead of overriding the old value. In other words, appending scope B to
|
||||
// scope A creates the conjunction of the scopes A and B (i.e, A & B) and,
|
||||
// in effect, the node gets both the old and new scopes. As a unique scope
|
||||
// disallows a node being merged with nodes in other scopes, the scope
|
||||
// conjunction preserves the semantic of the old scope (i.e., the node still
|
||||
// cannot be merged with the previously incompatible nodes.)
|
||||
//
|
||||
// For example, the below case should be rare in practice but can serve for the
|
||||
// purpose of discussion. After adding scopes for both Stage and Unstage,
|
||||
// Node_Y will receive both scopes "unstage" and "stage", while Node_X receives
|
||||
// only scope "stage". The semantic of scope "unstage" is preserved although
|
||||
// scope "stage" is later appended. As a result, Node_X and Node_Y will be put
|
||||
// into different clusters.
|
||||
//
|
||||
// Unstage -> Node_Y (scope "unstage & stage")
|
||||
// |
|
||||
// V
|
||||
// Node_X (scope "stage") -> Stage
|
||||
//
|
||||
void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) {
|
||||
string updated_scope;
|
||||
absl::optional<string> cur_scope = GetXlaInternalScope(node);
|
||||
if (cur_scope == absl::nullopt) {
|
||||
updated_scope = std::string(suffix);
|
||||
} else {
|
||||
updated_scope = absl::StrCat(cur_scope.value(), "&", suffix);
|
||||
}
|
||||
SetXlaInternalScope(node, updated_scope);
|
||||
}
|
||||
|
||||
void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) {
|
||||
const string unique_suffix = absl::StrCat("_", GetUniqueScopeId());
|
||||
|
||||
std::vector<Node*> starts;
|
||||
starts.push_back(start);
|
||||
auto enter = [&](Node* n) { AddOrAppendXlaInternalScope(n, unique_suffix); };
|
||||
ReverseDFSFrom(*graph_, starts, enter, /*leave=*/nullptr,
|
||||
/*stable_comparator=*/NodeComparatorName());
|
||||
}
|
||||
|
||||
void ClusterScopingPassImpl::AddScopeToAllTransitiveSuccessors(Node* start) {
|
||||
const string unique_suffix = absl::StrCat("_", GetUniqueScopeId());
|
||||
|
||||
std::vector<Node*> starts;
|
||||
starts.push_back(start);
|
||||
auto enter = [&](Node* n) { AddOrAppendXlaInternalScope(n, unique_suffix); };
|
||||
DFSFrom(*graph_, starts, enter, /*leave=*/nullptr,
|
||||
/*stable_comparator=*/NodeComparatorName(),
|
||||
// Do not filter any edges to better capture the semantics of
|
||||
// transitive closure of successors. We may revisit this when
|
||||
// we see more cases needing cluster scoping in the future.
|
||||
/*edge_filter=*/nullptr);
|
||||
}
|
||||
|
||||
// This preserves the parallelism between pipeline stages. For example, below
|
||||
// is a typical pattern of input pipelining in Tensorflow and this heuristic
|
||||
// ensures Node_X and Node_Y are put into different clusters. Without the
|
||||
// heuristic, they may be put into the same cluster and it can introduce
|
||||
// artificial dependencies and incur great performance loss. In this example,
|
||||
// Node_Y becomes dependent on IteratorGetNext and the latencies add up if
|
||||
// Node_X and Node_Y are in the same cluster.
|
||||
//
|
||||
// IteratorGetNext -> Node_X -> Stage
|
||||
//
|
||||
// Unstage -> Node_Y
|
||||
//
|
||||
Status ClusterScopingPassImpl::ScopingForPipelineStages() {
|
||||
for (Node* n : graph_->nodes()) {
|
||||
DCHECK(n);
|
||||
if (n->type_string() == "Unstage") {
|
||||
AddScopeToAllTransitiveSuccessors(n);
|
||||
}
|
||||
if (n->type_string() == "Stage") {
|
||||
AddScopeToAllTransitivePredecessors(n);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClusterScopingPassImpl::Run() {
|
||||
if (global_jit_level_ == OptimizerOptions::OFF) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return ScopingForPipelineStages();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status ClusterScopingPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
Graph* graph = options.graph->get();
|
||||
|
||||
return ClusterScopingPassImpl{graph, GetGlobalJitLevelForGraph(options)}
|
||||
.Run();
|
||||
}
|
||||
} // namespace tensorflow
|
38
tensorflow/compiler/jit/cluster_scoping_pass.h
Normal file
38
tensorflow/compiler/jit/cluster_scoping_pass.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This pass adds scopes to nodes in the _XlaInternalScope attribute to guide
|
||||
// the later clustering passes. A major reason to do this is to prevent the
|
||||
// clustering from losing critical parallelism in the Tensorflow graph, which
|
||||
// can incur great performance degradation.
|
||||
//
|
||||
// This pass must be run before MarkForCompilationPass, as it stores the
|
||||
// scoping information that MarkForCompilationPass will need to respect for
|
||||
// clustering decision.
|
||||
class ClusterScopingPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_
|
183
tensorflow/compiler/jit/cluster_scoping_pass_test.cc
Normal file
183
tensorflow/compiler/jit/cluster_scoping_pass_test.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ClusterScoping(std::unique_ptr<Graph>* graph) {
|
||||
FixupSourceAndSinkEdges(graph->get());
|
||||
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.graph = graph;
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
opt_options.flib_def = &flib_def;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||
opt_options.session_options = &session_options;
|
||||
|
||||
ClusterScopingPass pass;
|
||||
return pass.Run(opt_options);
|
||||
}
|
||||
|
||||
absl::flat_hash_map<string, string> GetXlaInternalScopes(const Graph& graph) {
|
||||
absl::flat_hash_map<string, string> scopes;
|
||||
for (Node* node : graph.nodes()) {
|
||||
string scope;
|
||||
if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
|
||||
scopes[node->name()] = scope;
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "_XlaInternalScopes:";
|
||||
for (const auto& p : scopes) {
|
||||
VLOG(2) << " " << p.first << " -> " << p.second;
|
||||
}
|
||||
}
|
||||
return scopes;
|
||||
}
|
||||
|
||||
Node* BuildStageNode(GraphDefBuilder& builder, string name,
|
||||
std::initializer_list<DataType> dtypes,
|
||||
absl::Span<const ops::NodeOut> values) {
|
||||
auto opts = builder.opts()
|
||||
.WithName(std::move(name))
|
||||
.WithAttr("dtypes", std::move(dtypes));
|
||||
if (opts.HaveError()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
NodeBuilder node_builder(name, "Stage", opts.op_registry());
|
||||
node_builder.Input(values);
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, StagePipelinePreserved) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
{
|
||||
// Graph:
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 (ClusterY) -> relu1 (ClusterY)
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("b")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* unstage = ops::SourceOp(
|
||||
"Unstage",
|
||||
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
|
||||
|
||||
Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
|
||||
Node* add1 =
|
||||
ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
|
||||
Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
|
||||
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
|
||||
BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(ClusterScoping(&graph));
|
||||
|
||||
auto scopes = GetXlaInternalScopes(*graph);
|
||||
EXPECT_NE(scopes["add0"], scopes["add1"]);
|
||||
EXPECT_EQ(scopes["add0"], scopes["relu0"]);
|
||||
EXPECT_EQ(scopes["add1"], scopes["relu1"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
{
|
||||
// Graph:
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 (ClusterC) -> relu1 (ClusterD)
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("b")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* unstage = ops::SourceOp(
|
||||
"Unstage",
|
||||
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
|
||||
|
||||
// Intentionally give add0 and add1 the same initial scope but they should
|
||||
// be separated by the ClusterScopingPass.
|
||||
Node* add0 = ops::BinaryOp("Add", a, b,
|
||||
builder.opts().WithName("add0").WithAttr(
|
||||
kXlaInternalScopeAttr, "ClusterA"));
|
||||
Node* add1 = ops::BinaryOp("Add", unstage, b,
|
||||
builder.opts().WithName("add1").WithAttr(
|
||||
kXlaInternalScopeAttr, "ClusterA"));
|
||||
Node* relu0 = ops::UnaryOp("Relu", add0,
|
||||
builder.opts().WithName("relu0").WithAttr(
|
||||
kXlaInternalScopeAttr, "ClusterB"));
|
||||
ops::UnaryOp("Relu", add1,
|
||||
builder.opts().WithName("relu1").WithAttr(
|
||||
kXlaInternalScopeAttr, "ClusterD"));
|
||||
BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||
|
||||
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(ClusterScoping(&graph));
|
||||
|
||||
auto scopes = GetXlaInternalScopes(*graph);
|
||||
EXPECT_NE(scopes["add0"], scopes["add1"]);
|
||||
EXPECT_NE(scopes["add0"], scopes["relu0"]);
|
||||
EXPECT_NE(scopes["add1"], scopes["relu1"]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -37,6 +37,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
@ -44,6 +46,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
@ -83,7 +86,7 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap
|
||||
RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
|
||||
@ -98,12 +101,14 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
}
|
||||
}
|
||||
stack_trace.emplace_back(StackFrameView{node.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
|
||||
IsCompilableNode(node, lib_runtime, &stack_trace,
|
||||
/*encapsulating_function=*/nullptr, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap
|
||||
RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<RecursiveCompilabilityChecker::StackFrame>*
|
||||
@ -118,8 +123,10 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
|
||||
}
|
||||
}
|
||||
stack_trace.emplace_back(StackFrameView{call_def.name(), ""});
|
||||
std::vector<UncompilableNodeInfo> uncompilable_nodes;
|
||||
IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes);
|
||||
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
|
||||
IsCompilableCall(call_def, lib_runtime, &stack_trace,
|
||||
/*encapsulating_function=*/nullptr, &uncompilable_nodes);
|
||||
return uncompilable_nodes;
|
||||
}
|
||||
|
||||
@ -154,16 +161,18 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const {
|
||||
bool RecursiveCompilabilityChecker::IsCompilableIf(
|
||||
const Node& if_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
bool is_compilable = true;
|
||||
is_compilable &= ExtractNodeDefAndCheckCompilability(
|
||||
if_node, "then_branch", "if_then", lib_runtime, stack_trace,
|
||||
uncompilable_nodes);
|
||||
if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
|
||||
stack_trace, uncompilable_nodes);
|
||||
if (!uncompilable_nodes && !is_compilable) return is_compilable;
|
||||
|
||||
is_compilable &= ExtractNodeDefAndCheckCompilability(
|
||||
if_node, "else_branch", "if_else", lib_runtime, stack_trace,
|
||||
uncompilable_nodes);
|
||||
if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
|
||||
stack_trace, uncompilable_nodes);
|
||||
|
||||
return is_compilable;
|
||||
}
|
||||
@ -174,37 +183,43 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
|
||||
bool RecursiveCompilabilityChecker::IsCompilableWhile(
|
||||
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
bool is_compilable = true;
|
||||
is_compilable &= ExtractNodeDefAndCheckCompilability(
|
||||
while_node, "cond", "while_cond", lib_runtime, stack_trace,
|
||||
uncompilable_nodes);
|
||||
while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
|
||||
stack_trace, uncompilable_nodes);
|
||||
|
||||
if (!uncompilable_nodes && !is_compilable) return is_compilable;
|
||||
|
||||
is_compilable &= ExtractNodeDefAndCheckCompilability(
|
||||
while_node, "body", "while_body", lib_runtime, stack_trace,
|
||||
uncompilable_nodes);
|
||||
while_node, "body", "while_body", encapsulating_function, lib_runtime,
|
||||
stack_trace, uncompilable_nodes);
|
||||
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
|
||||
const Node& node, const std::string& attr_name,
|
||||
const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::string& call_name, NameAttrList* encapsulating_function,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
NodeDef call;
|
||||
call.set_name(call_name);
|
||||
if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
|
||||
const auto uncompilable_reason = absl::StrCat(
|
||||
"missing '", attr_name, "' attribute from node", node.name());
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
|
||||
<< ".";
|
||||
return false;
|
||||
}
|
||||
if (!IsCompilableCall(call, lib_runtime, stack_trace, uncompilable_nodes)) {
|
||||
if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes)) {
|
||||
VLOG(2) << "Rejecting node " << node.name()
|
||||
<< ": can't compile : " << call.op();
|
||||
return false;
|
||||
@ -218,24 +233,33 @@ bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
|
||||
bool RecursiveCompilabilityChecker::IsCompilableCall(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
if (stack_trace->size() > kMaxRecursionDepth) {
|
||||
std::string uncompilable_reason = "function depth limit exceeded";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
|
||||
<< ".";
|
||||
return false;
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle);
|
||||
if (!status.ok()) {
|
||||
Status s;
|
||||
NameAttrList function;
|
||||
s = NameAndAttrsFromFunctionCall(call_def, &function);
|
||||
if (s.ok()) {
|
||||
s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
|
||||
&handle);
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
std::string uncompilable_reason = "could not instantiate call";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
|
||||
<< uncompilable_reason << " : " << status;
|
||||
<< uncompilable_reason << " : " << s;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -244,9 +268,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
|
||||
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
|
||||
bool is_compilable = true;
|
||||
for (const Node* node : fbody->graph->op_nodes()) {
|
||||
stack_trace->emplace_back(StackFrameView{node->name(), call_def.op()});
|
||||
is_compilable &=
|
||||
IsCompilableNode(*node, lib_runtime, stack_trace, uncompilable_nodes);
|
||||
stack_trace->emplace_back(StackFrameView{node->name(), function.name()});
|
||||
is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
|
||||
&function, uncompilable_nodes);
|
||||
stack_trace->pop_back();
|
||||
if (!uncompilable_nodes && !is_compilable) return is_compilable;
|
||||
}
|
||||
@ -263,20 +287,28 @@ bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
|
||||
bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
|
||||
// b/128001705: SelfAdjointEigV2 and Svd performance issues.
|
||||
// b/135640736: MatrixInverse performance issues.
|
||||
// https://github.com/tensorflow/tensorflow/pull/31012:
|
||||
// ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
|
||||
// create convolutions too large for CuDNN to handle.
|
||||
return node.type_string() == "SelfAdjointEigV2" ||
|
||||
node.type_string() == "Svd" || node.type_string() == "Qr" ||
|
||||
node.type_string() == "MatrixInverse";
|
||||
node.type_string() == "MatrixInverse" ||
|
||||
node.type_string() == "ResizeNearestNeighbor" ||
|
||||
node.type_string() == "ResizeBilinear" ||
|
||||
node.type_string() == "ResizeBilinearGrad";
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const {
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
auto stack_depth = stack_trace->size();
|
||||
if (node.IsSource() || node.IsSink()) {
|
||||
absl::string_view uncompilable_reason = "source or sink node";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -287,7 +319,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
(node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
|
||||
absl::string_view uncompilable_reason = "top level _Arg or _Retval";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -299,33 +331,35 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
absl::string_view uncompilable_reason =
|
||||
"_scoped_allocator or _forward_from attribute";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
|
||||
if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
|
||||
uncompilable_nodes)) {
|
||||
encapsulating_function, uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported function");
|
||||
return false;
|
||||
}
|
||||
} else if (!HasXLAKernel(node)) {
|
||||
absl::string_view uncompilable_reason = "unsupported op";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.type_string() == "While" &&
|
||||
!IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) {
|
||||
if (node.IsWhileNode() &&
|
||||
!IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported while");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.IsIfNode() &&
|
||||
!IsCompilableIf(node, lib_runtime, stack_trace, uncompilable_nodes)) {
|
||||
!IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported if");
|
||||
return false;
|
||||
}
|
||||
@ -334,7 +368,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
IsStatefulRandomOp(node.type_string())) {
|
||||
absl::string_view uncompilable_reason = "stateful random op";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -342,7 +376,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
|
||||
absl::string_view uncompilable_reason = "not allowed control trigger";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -351,7 +385,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
IsAssertOrCheckNumerics(node.type_string())) {
|
||||
absl::string_view uncompilable_reason = "Assert or CheckNumerics";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -360,7 +394,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
OpProducesOrConsumesVariant(node)) {
|
||||
absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -368,7 +402,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
|
||||
absl::string_view uncompilable_reason = "Stack op";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -376,7 +410,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
|
||||
absl::string_view uncompilable_reason = "TensorArray op";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -386,7 +420,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
absl::string_view uncompilable_reason =
|
||||
"resource variable op in called function";
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -394,16 +428,22 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
|
||||
absl::string_view uncompilable_reason =
|
||||
"operation with numerical accuracy issues";
|
||||
BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
|
||||
node.DebugString())
|
||||
.IgnoreError();
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
|
||||
absl::string_view uncompilable_reason = "slow operation";
|
||||
BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
|
||||
node.DebugString())
|
||||
.IgnoreError();
|
||||
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
|
||||
uncompilable_nodes);
|
||||
encapsulating_function, uncompilable_nodes);
|
||||
LogNotCompilable(node, uncompilable_reason);
|
||||
return false;
|
||||
}
|
||||
@ -432,8 +472,9 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
|
||||
const absl::string_view reason,
|
||||
const std::vector<StackFrameView>& stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_node_list) {
|
||||
if (!uncompilable_node_list) return;
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
|
||||
if (!uncompilable_nodes) return;
|
||||
|
||||
UncompilableNodeInfo node_info;
|
||||
node_info.uncompilable_reason = std::string(reason);
|
||||
@ -445,7 +486,20 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
});
|
||||
|
||||
node_info.name = std::string(stack_trace.back().name);
|
||||
(*uncompilable_node_list).push_back(std::move(node_info));
|
||||
auto function =
|
||||
encapsulating_function ? *encapsulating_function : NameAttrList();
|
||||
auto function_identifier = function.ShortDebugString();
|
||||
|
||||
auto it = uncompilable_nodes->find(function_identifier);
|
||||
if (it == uncompilable_nodes->end()) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
uncompileable_node_info{std::move(node_info)};
|
||||
uncompilable_nodes->emplace(
|
||||
std::move(function_identifier),
|
||||
std::make_pair(function, std::move(uncompileable_node_info)));
|
||||
} else {
|
||||
it->second.second.emplace_back(std::move(node_info));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
@ -129,19 +130,35 @@ class RecursiveCompilabilityChecker {
|
||||
const DeviceType* jit_device_type)
|
||||
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
|
||||
|
||||
// Returns a list of uncompilable nodes. When `node` is inside a function
|
||||
// body, users can set `node_stack_trace` to provide an additional
|
||||
// context for `node`'s placement within the outer most graph.
|
||||
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
|
||||
using UncompilableNodesMap =
|
||||
std::map<std::string,
|
||||
std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;
|
||||
|
||||
// Returns a map where the key is the function identifier(short debug
|
||||
// string) of the function encapsulating the uncompilable nodes, and the
|
||||
// value is a pair of NameAttrList of the function and a vector of
|
||||
// uncompilable node info. When uncompilable node is not inside any
|
||||
// function call nodes, then key is a ShortDebugString() of an empty
|
||||
// NameAttrList.
|
||||
//
|
||||
// Also, when `node` is inside a function body, users can set
|
||||
// `node_stack_trace` to provide an additional context for `node`'s
|
||||
// placement within the outer most graph.
|
||||
UncompilableNodesMap FindUncompilableNodes(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
|
||||
|
||||
// Returns a list of uncompilable nodes in `call_def` that cannot be
|
||||
// compiled by XLA. It is assumed that `call_def` is a call operation.
|
||||
// When `node` is inside a function body, users can set
|
||||
// Returns a map where the key is the function identifier(short debug
|
||||
// string) of the function encapsulating the uncompilable nodes, and the
|
||||
// value is a pair of NameAttrList of the function and a vector of
|
||||
// uncompilable node info. When uncompilable node is not inside any
|
||||
// function call nodes, then key is a ShortDebugString() of an empty
|
||||
// NameAttrList.
|
||||
//
|
||||
// Also, when `node` is inside a function body, users can set
|
||||
// `node_stack_trace` to provide an additional context for `node`'s
|
||||
// placement within the outer most graph.
|
||||
std::vector<UncompilableNodeInfo> FindUncompilableNodes(
|
||||
UncompilableNodesMap FindUncompilableNodes(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::vector<StackFrame>* node_stack_trace = nullptr) const;
|
||||
|
||||
@ -176,27 +193,31 @@ class RecursiveCompilabilityChecker {
|
||||
bool IsCompilableNode(
|
||||
const Node& node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
|
||||
NameAttrList* encapsulating_function = nullptr,
|
||||
UncompilableNodesMap* uncompilable_nodes = nullptr) const;
|
||||
bool IsCompilableCall(
|
||||
const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes = nullptr) const;
|
||||
bool IsCompilableIf(
|
||||
const Node& if_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
|
||||
bool IsCompilableWhile(
|
||||
const Node& while_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
|
||||
NameAttrList* encapsulating_function = nullptr,
|
||||
UncompilableNodesMap* uncompilable_nodes = nullptr) const;
|
||||
bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
bool IsCompilableWhile(const Node& while_node,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
// Returns compilability of node def retrieved from `node`'s attribute with
|
||||
// name `attr_name`.
|
||||
bool ExtractNodeDefAndCheckCompilability(
|
||||
const Node& node, const std::string& attr_name,
|
||||
const std::string& call_name, FunctionLibraryRuntime* lib_runtime,
|
||||
const std::string& call_name, NameAttrList* encapsulating_function,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_nodes) const;
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
bool IsStackOp(const Node& node) const {
|
||||
const XlaResourceOpInfo* op_info =
|
||||
@ -231,7 +252,8 @@ class RecursiveCompilabilityChecker {
|
||||
static void MaybeMarkUncompilableNode(
|
||||
const absl::string_view reason,
|
||||
const std::vector<StackFrameView>& stack_trace,
|
||||
std::vector<UncompilableNodeInfo>* uncompilable_node_list);
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes_map);
|
||||
|
||||
// Make sure we don't recurse infinitely on recursive functions.
|
||||
const int kMaxRecursionDepth = 10;
|
||||
|
@ -21,8 +21,10 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -117,10 +119,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
|
||||
const auto uncompilable_nodes =
|
||||
checker_->FindUncompilableNodes(*uncompilable_op, flib_runtime);
|
||||
ASSERT_EQ(1, uncompilable_nodes.size());
|
||||
const auto& node_info = uncompilable_nodes.at(0);
|
||||
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
|
||||
ASSERT_EQ(1, node_info.stack_trace.size());
|
||||
ASSERT_EQ("", node_info.stack_trace.at(0).function_name);
|
||||
auto node_info_it =
|
||||
uncompilable_nodes.find(NameAttrList().ShortDebugString());
|
||||
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
|
||||
const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
|
||||
ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
|
||||
const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0);
|
||||
EXPECT_EQ("unsupported op", uncompilable_node_info.uncompilable_reason);
|
||||
ASSERT_EQ(1, uncompilable_node_info.stack_trace.size());
|
||||
ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name);
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
|
||||
@ -147,12 +154,18 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
|
||||
checker_->FindUncompilableNodes(*functional_node, flib_runtime);
|
||||
|
||||
EXPECT_EQ(1, uncompilable_nodes.size());
|
||||
const auto& node_info = uncompilable_nodes.at(0);
|
||||
NameAttrList function;
|
||||
function.set_name(kUncompilableFunctionName);
|
||||
const auto node_info_it =
|
||||
uncompilable_nodes.find(function.ShortDebugString());
|
||||
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
|
||||
const auto& uncompilable_node_list = node_info_it->second.second;
|
||||
ASSERT_EQ(1, uncompilable_node_list.size());
|
||||
const auto& node_info = uncompilable_node_list.at(0);
|
||||
const auto& node_stack = node_info.stack_trace;
|
||||
ASSERT_EQ(2, node_stack.size());
|
||||
EXPECT_EQ("D", node_stack.at(0).name);
|
||||
EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);
|
||||
|
||||
EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
|
||||
EXPECT_EQ("unsupported op", node_info.uncompilable_reason);
|
||||
}
|
||||
@ -212,7 +225,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
|
||||
checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
|
||||
ASSERT_EQ(1, uncompilable_nodes.size());
|
||||
|
||||
const auto& node_info = uncompilable_nodes.at(0);
|
||||
NameAttrList function;
|
||||
function.set_name(kUncompilableFunctionName);
|
||||
const auto node_info_it =
|
||||
uncompilable_nodes.find(function.ShortDebugString());
|
||||
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
|
||||
const auto& uncompilable_node_list = node_info_it->second.second;
|
||||
ASSERT_EQ(1, uncompilable_node_list.size());
|
||||
const auto& node_info = uncompilable_node_list.at(0);
|
||||
|
||||
const auto& node_stack = node_info.stack_trace;
|
||||
ASSERT_EQ(2, node_stack.size());
|
||||
const auto& stacktrace_first_node_info = node_stack.at(0);
|
||||
@ -280,7 +301,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
|
||||
ASSERT_EQ(2, uncompilable_nodes.size());
|
||||
|
||||
const auto& uncompilable_node_one = uncompilable_nodes.at(0);
|
||||
NameAttrList function_one;
|
||||
function_one.set_name(kUncompilableFunctionName);
|
||||
auto it = uncompilable_nodes.find(function_one.ShortDebugString());
|
||||
ASSERT_NE(uncompilable_nodes.end(), it);
|
||||
|
||||
const auto& uncompilable_node_list = it->second.second;
|
||||
ASSERT_EQ(1, uncompilable_node_list.size());
|
||||
const auto& uncompilable_node_one = uncompilable_node_list.at(0);
|
||||
const auto& node_one_stack = uncompilable_node_one.stack_trace;
|
||||
|
||||
ASSERT_EQ(2, node_one_stack.size());
|
||||
@ -296,7 +324,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
|
||||
EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason);
|
||||
|
||||
const auto& uncompilable_node_two = uncompilable_nodes.at(1);
|
||||
NameAttrList function_two;
|
||||
function_two.set_name(kUncompilableFunctionTwoName);
|
||||
it = uncompilable_nodes.find(function_two.ShortDebugString());
|
||||
ASSERT_NE(uncompilable_nodes.end(), it);
|
||||
|
||||
const auto& uncompilable_node_two_list = it->second.second;
|
||||
ASSERT_EQ(1, uncompilable_node_two_list.size());
|
||||
const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
|
||||
const auto& node_two_stack = uncompilable_node_two.stack_trace;
|
||||
ASSERT_EQ(2, node_two_stack.size());
|
||||
const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
|
||||
|
@ -18,6 +18,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaCompileAttr = "_XlaCompile";
|
||||
|
||||
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.
|
||||
const char* const kXlaScopeAttr = "_XlaScope";
|
||||
|
||||
// Automatically inserted by auto_jit to guide clustering results. Effective
|
||||
// only when auto_jit is ON.
|
||||
const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,6 +24,7 @@ namespace tensorflow {
|
||||
// Name of attribute used to tag operators for compilation with XLA
|
||||
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1317,7 +1317,7 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
bool IsXlaCompiledKernel(const Node& node) {
|
||||
bool is_compiled = false;
|
||||
bool has_compilation_attr =
|
||||
GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
|
||||
TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
|
||||
is_compiled;
|
||||
return has_compilation_attr ? is_compiled : false;
|
||||
}
|
||||
|
@ -245,8 +245,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
|
||||
// while iterating.
|
||||
std::vector<Node*> launch_nodes;
|
||||
for (Node* n : graph->nodes()) {
|
||||
string name;
|
||||
if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
|
||||
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
|
||||
if (!name.empty()) {
|
||||
launch_nodes.push_back(n);
|
||||
}
|
||||
}
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
@ -33,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
@ -369,7 +369,8 @@ xla::StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
|
||||
return new_def;
|
||||
}
|
||||
|
||||
Status ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
TF_ATTRIBUTE_NOINLINE Status
|
||||
ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
// DT_INT64 as input/output for outside compilation is not supported yet:
|
||||
// b/120809951.
|
||||
for (const Edge* e : call_node->in_edges()) {
|
||||
@ -402,7 +403,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) {
|
||||
}
|
||||
|
||||
// Replace outside compilation function call node with XlaHostCompute node.
|
||||
xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
|
||||
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> ReplaceOutsideCompilationCallNode(
|
||||
Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
|
||||
const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
|
||||
// Build XlaHostCompute NodeDef.
|
||||
@ -440,7 +441,7 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
|
||||
n->ClearAttr(attr_name);
|
||||
n->AddAttr(attr_name, branch_func);
|
||||
}
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
for (const string& attr_name : std::vector<string>{"cond", "body"}) {
|
||||
NameAttrList branch_func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
|
||||
@ -523,16 +524,14 @@ xla::StatusOr<std::vector<DataType>> UpdateTypesAttribute(
|
||||
|
||||
// Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
|
||||
void AddEdgesFromOutsideCompilationNodes(
|
||||
const int original_arg_count, const std::vector<DataType>& data_types,
|
||||
const std::vector<std::pair<Node*, Node*>>&
|
||||
lifted_arg_nodes_and_outside_compilation_nodes,
|
||||
Graph* g, Node* n) {
|
||||
const int original_arg_count, const int arg_to_input_edge_offset,
|
||||
const std::vector<DataType>& data_types,
|
||||
const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
|
||||
// Add edges from outside compilation nodes to While node.
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
Node* outside_compilation_node =
|
||||
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
|
||||
.second;
|
||||
g->AddEdge(outside_compilation_node, 0, n, i);
|
||||
outside_compilation_nodes[i - original_arg_count];
|
||||
g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@ -573,14 +572,15 @@ Status AddMatchingRetvalNode(const FunctionBody& function_body,
|
||||
|
||||
void ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
const FunctionBody& function_body, const int original_arg_count,
|
||||
const int arg_idx,
|
||||
const std::vector<std::pair<Node*, Node*>>&
|
||||
lifted_arg_nodes_and_outside_compilation_nodes,
|
||||
const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
|
||||
Node* arg_node) {
|
||||
Node* lifted_arg_node =
|
||||
lifted_arg_nodes_and_outside_compilation_nodes[arg_idx -
|
||||
original_arg_count]
|
||||
.first;
|
||||
Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
|
||||
// This might happen because lifted_arg_node only exists in one branch of an
|
||||
// If node, and we are handling the other branch.
|
||||
if (!lifted_arg_node) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const Edge* e : lifted_arg_node->out_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
function_body.graph->AddControlEdge(arg_node, e->dst());
|
||||
@ -588,7 +588,6 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
|
||||
}
|
||||
}
|
||||
|
||||
function_body.graph->RemoveNode(lifted_arg_node);
|
||||
}
|
||||
|
||||
@ -597,7 +596,7 @@ void ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
Status PostprocessLiftedArgsForWhile(
|
||||
const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld) {
|
||||
TF_RET_CHECK(n->type_string() == "While");
|
||||
TF_RET_CHECK(n->IsWhileNode());
|
||||
|
||||
// Check if there is any lifted args in body function.
|
||||
NameAttrList body_func;
|
||||
@ -629,12 +628,25 @@ Status PostprocessLiftedArgsForWhile(
|
||||
n));
|
||||
|
||||
// Add edges from outside compilation nodes to While node.
|
||||
AddEdgesFromOutsideCompilationNodes(
|
||||
original_arg_count, data_types,
|
||||
lifted_arg_nodes_and_outside_compilation_nodes, g, n);
|
||||
std::vector<Node*> outside_compilation_nodes;
|
||||
std::transform(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.begin(),
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(outside_compilation_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.second; });
|
||||
AddEdgesFromOutsideCompilationNodes(original_arg_count,
|
||||
/*arg_to_input_edge_offset=*/0,
|
||||
data_types, outside_compilation_nodes, g,
|
||||
n);
|
||||
|
||||
// In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
|
||||
// nodes with the new _Arg nodes.
|
||||
std::vector<Node*> lifted_arg_nodes;
|
||||
std::transform(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.begin(),
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(lifted_arg_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.first; });
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
TF_ASSIGN_OR_RETURN(Node * arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(
|
||||
@ -644,8 +656,7 @@ Status PostprocessLiftedArgsForWhile(
|
||||
AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));
|
||||
|
||||
ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
*body_function_body, original_arg_count, i,
|
||||
lifted_arg_nodes_and_outside_compilation_nodes, arg_node);
|
||||
*body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
|
||||
}
|
||||
|
||||
FunctionDef rewritten_body_function_def;
|
||||
@ -682,6 +693,219 @@ Status PostprocessLiftedArgsForWhile(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PostprocessLiftedArgsForIf(
|
||||
const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld) {
|
||||
TF_RET_CHECK(n->IsIfNode());
|
||||
|
||||
NameAttrList then_branch_func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
|
||||
const FunctionDef* then_branch_function_def =
|
||||
fld->Find(then_branch_func.name());
|
||||
TF_RET_CHECK(then_branch_function_def);
|
||||
|
||||
NameAttrList else_branch_func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
|
||||
const FunctionDef* else_branch_function_def =
|
||||
fld->Find(else_branch_func.name());
|
||||
TF_RET_CHECK(else_branch_function_def);
|
||||
|
||||
// Nothing to do if neither branch contains any lifted arguments.
|
||||
if (!HasLiftedArgs(*then_branch_function_def) &&
|
||||
!HasLiftedArgs(*else_branch_function_def)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionBody> then_branch_function_body;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
|
||||
&then_branch_function_body));
|
||||
|
||||
std::unique_ptr<FunctionBody> else_branch_function_body;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
|
||||
*else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
|
||||
&else_branch_function_body));
|
||||
|
||||
// Then and else branches have same argument count and argument data types.
|
||||
int original_arg_count = then_branch_function_body->arg_nodes.size();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
|
||||
LiftedArgsAndOutsideCompilationNodesInFunctionBody(
|
||||
*then_branch_function_body, outside_compilation_attr_to_node));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
|
||||
LiftedArgsAndOutsideCompilationNodesInFunctionBody(
|
||||
*else_branch_function_body, outside_compilation_attr_to_node));
|
||||
|
||||
// Merge lifted args from then and else branches.
|
||||
std::vector<Node*> outside_compilation_nodes;
|
||||
std::vector<Node*> then_branch_lifted_arg_nodes;
|
||||
for (const auto& pair :
|
||||
then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
|
||||
outside_compilation_nodes.push_back(pair.second);
|
||||
then_branch_lifted_arg_nodes.push_back(pair.first);
|
||||
}
|
||||
for (const auto& pair :
|
||||
else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
|
||||
if (std::find(outside_compilation_nodes.begin(),
|
||||
outside_compilation_nodes.end(),
|
||||
pair.second) == outside_compilation_nodes.end()) {
|
||||
outside_compilation_nodes.push_back(pair.second);
|
||||
// Then branch does not contain this lifted arg. Add an empty item to
|
||||
// then_branch_lifted_arg_nodes.
|
||||
then_branch_lifted_arg_nodes.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
// Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
|
||||
std::vector<Node*> else_branch_lifted_arg_nodes(
|
||||
outside_compilation_nodes.size());
|
||||
for (const auto& pair :
|
||||
else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
|
||||
auto iter = std::find(outside_compilation_nodes.begin(),
|
||||
outside_compilation_nodes.end(), pair.second);
|
||||
TF_RET_CHECK(iter != outside_compilation_nodes.end());
|
||||
int index = iter - outside_compilation_nodes.begin();
|
||||
else_branch_lifted_arg_nodes[index] = pair.first;
|
||||
}
|
||||
|
||||
// Append lifted args' types to If node's Tin attribute.
|
||||
std::vector<DataType> data_types;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
|
||||
for (Node* n : outside_compilation_nodes) {
|
||||
data_types.push_back(n->output_type(0));
|
||||
}
|
||||
n->ClearAttr("Tin");
|
||||
n->AddAttr("Tin", data_types);
|
||||
|
||||
// Add edges from outside compilation nodes to If node. If node's input #0
|
||||
// is predicate input, input #1 maps to _Arg #0 of branch functions, thus
|
||||
// arg_to_input_edge_offset is set to 1.
|
||||
AddEdgesFromOutsideCompilationNodes(original_arg_count,
|
||||
/*arg_to_input_edge_offset=*/1,
|
||||
data_types, outside_compilation_nodes, g,
|
||||
n);
|
||||
|
||||
for (int i = original_arg_count; i < data_types.size(); ++i) {
|
||||
TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(
|
||||
*then_branch_function_body, i, data_types[i]));
|
||||
|
||||
ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
*then_branch_function_body, original_arg_count, i,
|
||||
then_branch_lifted_arg_nodes, then_branch_arg_node);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(
|
||||
*else_branch_function_body, i, data_types[i]));
|
||||
|
||||
ReplaceLiftedArgNodePlaceholderWithArg(
|
||||
*else_branch_function_body, original_arg_count, i,
|
||||
else_branch_lifted_arg_nodes, else_branch_arg_node);
|
||||
}
|
||||
|
||||
FunctionDef rewritten_then_branch_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*then_branch_function_body->graph, then_branch_func.name(),
|
||||
HostGraphControlRetMapping, &rewritten_then_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(),
|
||||
rewritten_then_branch_function_def));
|
||||
|
||||
FunctionDef rewritten_else_branch_function_def;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(
|
||||
*else_branch_function_body->graph, else_branch_func.name(),
|
||||
HostGraphControlRetMapping, &rewritten_else_branch_function_def));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(),
|
||||
rewritten_else_branch_function_def));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PostprocessLiftedArgsForCall(
|
||||
const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
|
||||
Graph* g, Node* n, FunctionLibraryDefinition* fld) {
|
||||
const FunctionDef* fdef = fld->Find(n->type_string());
|
||||
TF_RET_CHECK(fdef);
|
||||
|
||||
// Nothing to do if the function does not contain any lifted arguments.
|
||||
if (!HasLiftedArgs(*fdef)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));
|
||||
|
||||
int original_arg_count = fbody->arg_nodes.size();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
|
||||
LiftedArgsAndOutsideCompilationNodesInFunctionBody(
|
||||
*fbody, outside_compilation_attr_to_node));
|
||||
|
||||
// Append lifted args' types to call node's input data types.
|
||||
std::vector<DataType> data_types(n->input_types().begin(),
|
||||
n->input_types().end());
|
||||
for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
|
||||
Node* outside_compilation_node = pair.second;
|
||||
DataType data_type;
|
||||
TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
|
||||
outside_compilation_node->type_string() == "Placeholder");
|
||||
if (outside_compilation_node->IsIdentity()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
|
||||
}
|
||||
data_types.push_back(data_type);
|
||||
}
|
||||
|
||||
std::vector<Node*> lifted_arg_nodes;
|
||||
std::transform(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.begin(),
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(lifted_arg_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.first; });
|
||||
for (int i = original_arg_count; i < data_types.size(); ++i) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Node * arg_node,
|
||||
AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
|
||||
|
||||
ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
|
||||
lifted_arg_nodes, arg_node);
|
||||
}
|
||||
|
||||
FunctionDef rewritten_fdef;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
|
||||
HostGraphControlRetMapping,
|
||||
&rewritten_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef));
|
||||
|
||||
// We need to recreate the node. Otherwise TF will not know n->num_inputs()
|
||||
// has increased.
|
||||
NodeDef node_def = n->def();
|
||||
for (int i = original_arg_count; i < data_types.size(); i++) {
|
||||
Node* outside_compilation_node =
|
||||
lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
|
||||
.second;
|
||||
node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));
|
||||
|
||||
// Add edges from outside compilation nodes to call node.
|
||||
std::vector<Node*> outside_compilation_nodes;
|
||||
std::transform(
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.begin(),
|
||||
lifted_arg_nodes_and_outside_compilation_nodes.end(),
|
||||
std::back_inserter(outside_compilation_nodes),
|
||||
[](const std::pair<Node*, Node*>& pair) { return pair.second; });
|
||||
AddEdgesFromOutsideCompilationNodes(original_arg_count,
|
||||
/*arg_to_input_edge_offset=*/0,
|
||||
data_types, outside_compilation_nodes, g,
|
||||
n);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a mapping from outside compilation cluster name to lifted argument
|
||||
// placeholder.
|
||||
xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
|
||||
@ -690,10 +914,9 @@ xla::StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
|
||||
for (Node* n : g.op_nodes()) {
|
||||
bool is_lifted_arg;
|
||||
string outside_compilation_attr;
|
||||
if (GetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg).ok() &&
|
||||
GetNodeAttr(n->def(), "_xla_outside_compilation",
|
||||
&outside_compilation_attr)
|
||||
.ok()) {
|
||||
if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
|
||||
TryGetNodeAttr(n->def(), "_xla_outside_compilation",
|
||||
&outside_compilation_attr)) {
|
||||
TF_RET_CHECK(is_lifted_arg);
|
||||
TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
|
||||
outside_compilation_attr_to_node[outside_compilation_attr] = n;
|
||||
@ -707,15 +930,34 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
|
||||
TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
|
||||
OutsideCompilationAttrToNode(*g));
|
||||
|
||||
std::vector<Node*> call_nodes;
|
||||
for (Node* n : g->op_nodes()) {
|
||||
if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (n->type_string() == "While") {
|
||||
if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
|
||||
outside_compilation_attr_to_node, g, n, fld));
|
||||
}
|
||||
|
||||
if (n->IsIfNode()) {
|
||||
TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
|
||||
outside_compilation_attr_to_node, g, n, fld));
|
||||
}
|
||||
|
||||
// Outside compilation host side function call will always be direct
|
||||
// function call nodes.
|
||||
// Function call nodes need to be handled separately because we rewrite
|
||||
// nodes in `PostprocessLiftedArgsForCall`.
|
||||
if (fld->Contains(n->type_string())) {
|
||||
call_nodes.push_back(n);
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : call_nodes) {
|
||||
TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
|
||||
outside_compilation_attr_to_node, g, n, fld));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -1065,9 +1307,9 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
|
||||
}
|
||||
|
||||
// Builds XlaSendToHost node which sends cond predicate to host.
|
||||
xla::StatusOr<Node*> BuildSendIfPredNode(const string& name,
|
||||
const string& host_transfer_key,
|
||||
Node* pred_node, Graph* g) {
|
||||
TF_ATTRIBUTE_NOINLINE xla::StatusOr<Node*> BuildSendIfPredNode(
|
||||
const string& name, const string& host_transfer_key, Node* pred_node,
|
||||
Graph* g) {
|
||||
NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
|
||||
send_pred_builder.Attr("Tinput", DT_BOOL);
|
||||
send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
|
||||
@ -1130,15 +1372,13 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
|
||||
}
|
||||
|
||||
// Builds host side graph for If node.
|
||||
Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name,
|
||||
const string& xla_cluster_name,
|
||||
const string& if_node_name,
|
||||
const string& host_transfer_key,
|
||||
const string& host_graph_func_name,
|
||||
FunctionLibraryDefinition* fld,
|
||||
const string& then_branch_host_func_name,
|
||||
const string& else_branch_host_func_name) {
|
||||
TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
const string& if_node_name, const string& host_transfer_key,
|
||||
const string& host_graph_func_name, FunctionLibraryDefinition* fld,
|
||||
const string& then_branch_host_func_name,
|
||||
const string& else_branch_host_func_name) {
|
||||
Graph host_graph(fld);
|
||||
string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
|
||||
AttrValue device_ordinal_value;
|
||||
@ -1215,10 +1455,9 @@ Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
|
||||
}
|
||||
|
||||
// Rewrites loop cond to add a node which sends loop cond to host.
|
||||
Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
|
||||
const NameAttrList& loop_cond_func,
|
||||
const string& while_node_name,
|
||||
const string& host_transfer_key) {
|
||||
TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
|
||||
FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func,
|
||||
const string& while_node_name, const string& host_transfer_key) {
|
||||
// Instantiate the loop cond function.
|
||||
std::unique_ptr<FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()),
|
||||
@ -1406,7 +1645,7 @@ Status RewriteHostWhileLoopBody(
|
||||
}
|
||||
|
||||
// Builds host side graph for while node.
|
||||
Status BuildHostGraphForWhileNode(
|
||||
TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
const string& while_node_name, const string& host_transfer_key,
|
||||
@ -1503,10 +1742,6 @@ Status BuildHostGraphForFuncCallNode(
|
||||
call_builder.Attr(kXlaHasHostTransferAttrName, true);
|
||||
call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
|
||||
call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
|
||||
// Make sure control outputs of this function call node will be respected when
|
||||
// this node is lowered.
|
||||
call_builder.Attr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
|
||||
true);
|
||||
NodeDef call_def;
|
||||
TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
|
||||
Status s;
|
||||
@ -1529,6 +1764,221 @@ Status BuildHostGraphForFuncCallNode(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
const std::map<string, int>& host_compute_core, Graph* g, Node* n,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
std::vector<string>* host_graphs,
|
||||
std::vector<string>* shape_inference_graphs,
|
||||
bool* has_outside_compilation) {
|
||||
bool func_has_outside_compilation = false;
|
||||
NameAttrList func;
|
||||
if (fld->Contains(n->type_string())) {
|
||||
func.set_name(n->type_string());
|
||||
typedef protobuf::Map<string, AttrValue> AttrMap;
|
||||
*func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
|
||||
} else if (n->IsPartitionedCall()) {
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
|
||||
} else {
|
||||
TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
|
||||
func.set_name(FunctionLibraryDefinition::kGradientOp);
|
||||
*func.mutable_attr() = n->def().attr();
|
||||
}
|
||||
string new_func_name = absl::StrCat(n->name(), "_oc");
|
||||
string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
func, new_func_name, host_func_name, host_compute_core, flr, fld,
|
||||
shape_inference_graphs, &func_has_outside_compilation));
|
||||
|
||||
// If the function call does not have outside compilation, nothing to do.
|
||||
if (!func_has_outside_compilation) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change `n` to call the new function directly.
|
||||
auto replace_builder =
|
||||
absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
|
||||
std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size());
|
||||
inputs[e->dst_input()] =
|
||||
NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
|
||||
e->src()->output_type(e->src_output())};
|
||||
}
|
||||
for (const auto& input : inputs) {
|
||||
replace_builder->Input(input);
|
||||
}
|
||||
for (const auto& attr : n->attrs()) {
|
||||
replace_builder->Attr(attr.first, attr.second);
|
||||
}
|
||||
auto replace_def = absl::make_unique<NodeDef>();
|
||||
TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
|
||||
TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
|
||||
replace->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
|
||||
// Build host side graph for the function call.
|
||||
string oc_host_graph_name =
|
||||
absl::StrCat("oc_func_host_graph_", replace->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
|
||||
xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
|
||||
replace->name(), host_func_name, oc_host_graph_name, fld));
|
||||
|
||||
// Record the host graph.
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExtractOutsideCompilationForIfNode(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
const std::map<string, int>& host_compute_core, Graph* g, Node* n,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
std::vector<string>* host_graphs,
|
||||
std::vector<string>* shape_inference_graphs,
|
||||
bool* has_outside_compilation) {
|
||||
// Instantiate "then_branch" and "else_branch".
|
||||
NameAttrList then_branch, else_branch;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
|
||||
|
||||
// Extract outside compilation for then_branch and else_branch.
|
||||
bool then_branch_has_outside_compilation = false;
|
||||
bool else_branch_has_outside_compilation = false;
|
||||
string then_branch_host_func_name =
|
||||
absl::StrCat("oc_then_branch_host_if_", n->name()),
|
||||
else_branch_host_func_name =
|
||||
absl::StrCat("oc_else_branch_host_if_", n->name());
|
||||
string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
|
||||
else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
then_branch, then_branch_xla_func_name, then_branch_host_func_name,
|
||||
host_compute_core, flr, fld, shape_inference_graphs,
|
||||
&then_branch_has_outside_compilation));
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
else_branch, else_branch_xla_func_name, else_branch_host_func_name,
|
||||
host_compute_core, flr, fld, shape_inference_graphs,
|
||||
&else_branch_has_outside_compilation));
|
||||
|
||||
// If then/else branch do not have outside compilation, nothing to do.
|
||||
if (!then_branch_has_outside_compilation &&
|
||||
!else_branch_has_outside_compilation) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change If node to call the new functions.
|
||||
then_branch.set_name(then_branch_xla_func_name);
|
||||
n->ClearAttr("then_branch");
|
||||
n->AddAttr("then_branch", then_branch);
|
||||
else_branch.set_name(else_branch_xla_func_name);
|
||||
n->ClearAttr("else_branch");
|
||||
n->AddAttr("else_branch", else_branch);
|
||||
|
||||
string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
|
||||
|
||||
// XLA computation: add a SendToHost node to send cond predicate.
|
||||
Node* pred_node;
|
||||
TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Node * send_pred_node,
|
||||
BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
|
||||
host_transfer_key, pred_node, g));
|
||||
n->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{send_pred_node->name()});
|
||||
|
||||
// Add a control edge from `send_pred_node` to If node, so XlaCompiler will
|
||||
// visit If node after `send_pred_node`, thus the token output for
|
||||
// `send_pred_node` has been generated.
|
||||
g->AddControlEdge(send_pred_node, n);
|
||||
|
||||
// Build host side graph for the "If" node.
|
||||
string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
n->name(), host_transfer_key, oc_host_graph_name, fld,
|
||||
then_branch_host_func_name, else_branch_host_func_name));
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExtractOutsideCompilationForWhileNode(
|
||||
const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
const std::map<string, int>& host_compute_core, Graph* g, Node* n,
|
||||
FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
|
||||
std::vector<string>* host_graphs,
|
||||
std::vector<string>* shape_inference_graphs,
|
||||
bool* has_outside_compilation) {
|
||||
// Instantiate "cond" and "body".
|
||||
NameAttrList cond, body;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
|
||||
|
||||
// Extract outside compilation for cond and body.
|
||||
bool cond_has_outside_compilation = false;
|
||||
bool body_has_outside_compilation = false;
|
||||
string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
|
||||
body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
|
||||
string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
|
||||
body_xla_func_name = absl::StrCat(body.name(), "_oc");
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
|
||||
fld, shape_inference_graphs, &cond_has_outside_compilation));
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
|
||||
fld, shape_inference_graphs, &body_has_outside_compilation));
|
||||
|
||||
// If cond/body do not have outside compilation, nothing to do.
|
||||
if (!cond_has_outside_compilation && !body_has_outside_compilation) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change While node to call the new functions.
|
||||
cond.set_name(cond_xla_func_name);
|
||||
n->ClearAttr("cond");
|
||||
n->AddAttr("cond", cond);
|
||||
body.set_name(body_xla_func_name);
|
||||
n->ClearAttr("body");
|
||||
n->AddAttr("body", body);
|
||||
|
||||
string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
|
||||
|
||||
// XLA computation: rewrite cond function to add a SendToHost node to send
|
||||
// loop predicate.
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
|
||||
n->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
|
||||
// Build host side graph for the "While" node.
|
||||
string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
n->name(), host_transfer_key, oc_host_graph_name, fld,
|
||||
cond_host_func_name, body_host_func_name));
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
|
||||
Graph* g, const string& xla_cluster_attr_name,
|
||||
const string& outside_compilation_attr_name, const string& xla_cluster_name,
|
||||
@ -1540,193 +1990,32 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
|
||||
for (Node* n : g->nodes()) {
|
||||
if (n->IsIfNode()) {
|
||||
if_nodes.push_back(n);
|
||||
} else if (n->type_string() == "While") {
|
||||
} else if (n->IsWhileNode()) {
|
||||
while_nodes.push_back(n);
|
||||
} else if (fld->Contains(n->type_string())) {
|
||||
} else if (IsFunctionCall(*fld, *n)) {
|
||||
func_call_nodes.push_back(n);
|
||||
} else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
|
||||
// Only gradient for user-defined function should be considered as
|
||||
// function call node.
|
||||
NameAttrList original_func;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(
|
||||
n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func));
|
||||
if (fld->Contains(original_func.name())) {
|
||||
func_call_nodes.push_back(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : func_call_nodes) {
|
||||
// Extract outside compilation for the function call.
|
||||
bool func_has_outside_compilation = false;
|
||||
NameAttrList func;
|
||||
func.set_name(n->type_string());
|
||||
typedef protobuf::Map<string, AttrValue> AttrMap;
|
||||
*func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
|
||||
string new_func_name = absl::StrCat(n->name(), "_oc");
|
||||
string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
func, new_func_name, host_func_name, host_compute_core, flr, fld,
|
||||
shape_inference_graphs, &func_has_outside_compilation));
|
||||
|
||||
// If the function call does not have outside compilation, nothing to do.
|
||||
if (!func_has_outside_compilation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change `n` to call the new function directly.
|
||||
NodeDefBuilder replace_builder(n->name(), new_func_name, fld);
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
replace_builder.Input(e->src()->name(), e->src_output(),
|
||||
e->src()->output_type(e->src_output()));
|
||||
}
|
||||
for (const auto& attr : n->attrs()) {
|
||||
replace_builder.Attr(attr.first, attr.second);
|
||||
}
|
||||
NodeDef replace_def;
|
||||
TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def));
|
||||
TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def));
|
||||
replace->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
|
||||
// Build host side graph for the function call.
|
||||
string oc_host_graph_name =
|
||||
absl::StrCat("oc_func_host_graph_", replace->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
|
||||
xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
|
||||
replace->name(), host_func_name, oc_host_graph_name, fld));
|
||||
|
||||
// Record the host graph.
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
|
||||
has_outside_compilation));
|
||||
}
|
||||
|
||||
for (Node* n : if_nodes) {
|
||||
// Instantiate "then_branch" and "else_branch".
|
||||
NameAttrList then_branch, else_branch;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
|
||||
|
||||
// Extract outside compilation for then_branch and else_branch.
|
||||
bool then_branch_has_outside_compilation = false;
|
||||
bool else_branch_has_outside_compilation = false;
|
||||
string then_branch_host_func_name =
|
||||
absl::StrCat("oc_then_branch_host_if_", n->name()),
|
||||
else_branch_host_func_name =
|
||||
absl::StrCat("oc_else_branch_host_if_", n->name());
|
||||
string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
|
||||
else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
then_branch, then_branch_xla_func_name, then_branch_host_func_name,
|
||||
host_compute_core, flr, fld, shape_inference_graphs,
|
||||
&then_branch_has_outside_compilation));
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
else_branch, else_branch_xla_func_name, else_branch_host_func_name,
|
||||
host_compute_core, flr, fld, shape_inference_graphs,
|
||||
&else_branch_has_outside_compilation));
|
||||
|
||||
// If then/else branch do not have outside compilation, nothing to do.
|
||||
if (!then_branch_has_outside_compilation &&
|
||||
!else_branch_has_outside_compilation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change If node to call the new functions.
|
||||
then_branch.set_name(then_branch_xla_func_name);
|
||||
n->ClearAttr("then_branch");
|
||||
n->AddAttr("then_branch", then_branch);
|
||||
else_branch.set_name(else_branch_xla_func_name);
|
||||
n->ClearAttr("else_branch");
|
||||
n->AddAttr("else_branch", else_branch);
|
||||
|
||||
string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
|
||||
|
||||
// XLA computation: add a SendToHost node to send cond predicate.
|
||||
Node* pred_node;
|
||||
TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Node * send_pred_node,
|
||||
BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
|
||||
host_transfer_key, pred_node, g));
|
||||
n->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{send_pred_node->name()});
|
||||
|
||||
// Add a control edge from `send_pred_node` to If node, so XlaCompiler will
|
||||
// visit If node after `send_pred_node`, thus the token output for
|
||||
// `send_pred_node` has been generated.
|
||||
g->AddControlEdge(send_pred_node, n);
|
||||
|
||||
// Build host side graph for the "If" node.
|
||||
string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
n->name(), host_transfer_key, oc_host_graph_name, fld,
|
||||
then_branch_host_func_name, else_branch_host_func_name));
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
|
||||
has_outside_compilation));
|
||||
}
|
||||
|
||||
for (Node* n : while_nodes) {
|
||||
// Instantiate "cond" and "body".
|
||||
NameAttrList cond, body;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
|
||||
|
||||
// Extract outside compilation for cond and body.
|
||||
bool cond_has_outside_compilation = false;
|
||||
bool body_has_outside_compilation = false;
|
||||
string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
|
||||
body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
|
||||
string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
|
||||
body_xla_func_name = absl::StrCat(body.name(), "_oc");
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
|
||||
fld, shape_inference_graphs, &cond_has_outside_compilation));
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
|
||||
fld, shape_inference_graphs, &body_has_outside_compilation));
|
||||
|
||||
// If cond/body do not have outside compilation, nothing to do.
|
||||
if (!cond_has_outside_compilation && !body_has_outside_compilation) {
|
||||
continue;
|
||||
}
|
||||
|
||||
*has_outside_compilation = true;
|
||||
|
||||
// Change While node to call the new functions.
|
||||
cond.set_name(cond_xla_func_name);
|
||||
n->ClearAttr("cond");
|
||||
n->AddAttr("cond", cond);
|
||||
body.set_name(body_xla_func_name);
|
||||
n->ClearAttr("body");
|
||||
n->AddAttr("body", body);
|
||||
|
||||
string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
|
||||
|
||||
// XLA computation: rewrite cond function to add a SendToHost node to send
|
||||
// loop predicate.
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
|
||||
n->AddAttr(kXlaTokenInputNodesAttrName,
|
||||
std::vector<string>{kXlaTokenArgNodeName});
|
||||
|
||||
// Build host side graph for the "While" node.
|
||||
string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
|
||||
TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
n->name(), host_transfer_key, oc_host_graph_name, fld,
|
||||
cond_host_func_name, body_host_func_name));
|
||||
host_graphs->push_back(oc_host_graph_name);
|
||||
host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
|
||||
has_outside_compilation));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -1889,11 +2178,11 @@ Status ExtractOutsideCompilationForFunction(
|
||||
|
||||
// Encapsulate outside_compilation cluster into function call node.
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
RewriteOutsideCompilationSubgraphFn rewrite_fn(
|
||||
auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
new_func_name);
|
||||
TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
|
||||
outside_compilation_attr_name, *fbody->graph, rewrite_fn,
|
||||
outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
|
||||
/*reuse_existing_functions=*/true, &graph_out, fld));
|
||||
|
||||
// Replace outside_compilation function nodes with HostCompute ops.
|
||||
@ -1908,26 +2197,26 @@ Status ExtractOutsideCompilationForFunction(
|
||||
// If we could not infer shapes for XlaSendFromHost inputs statically, we
|
||||
// will set the "shape_inference_graph" attribute. In that case, copy
|
||||
// outside compilation subgraph as shape inference graph in `fld`.
|
||||
NameAttrList shape_inference_graph;
|
||||
auto shape_inference_graph = absl::make_unique<NameAttrList>();
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
|
||||
&shape_inference_graph));
|
||||
if (!shape_inference_graph.name().empty()) {
|
||||
shape_inference_graphs->push_back(shape_inference_graph.name());
|
||||
shape_inference_graph.get()));
|
||||
if (!shape_inference_graph->name().empty()) {
|
||||
shape_inference_graphs->push_back(shape_inference_graph->name());
|
||||
shape_inference_graphs_to_rewrite.push_back(
|
||||
shape_inference_graph.name());
|
||||
shape_inference_graph->name());
|
||||
|
||||
const FunctionDef* xla_fdef = fld->Find(n->name());
|
||||
if (!xla_fdef) {
|
||||
return errors::Internal("Cannot find XLA function ", n->name());
|
||||
}
|
||||
FunctionDef shape_inference_fdef = *xla_fdef;
|
||||
shape_inference_fdef.mutable_signature()->set_name(
|
||||
shape_inference_graph.name());
|
||||
if (fld->Find(shape_inference_graph.name())) {
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(),
|
||||
shape_inference_fdef));
|
||||
auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
|
||||
shape_inference_fdef->mutable_signature()->set_name(
|
||||
shape_inference_graph->name());
|
||||
if (fld->Find(shape_inference_graph->name())) {
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(),
|
||||
*shape_inference_fdef));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1972,15 +2261,15 @@ Status ExtractOutsideCompilationForFunction(
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
|
||||
outside_compilation_host_graphs, fld, &host_graph));
|
||||
FunctionDef host_graph_fdef;
|
||||
auto host_graph_fdef = absl::make_unique<FunctionDef>();
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
|
||||
HostGraphControlRetMapping,
|
||||
&host_graph_fdef));
|
||||
host_graph_fdef.get()));
|
||||
if (fld->Find(host_graph_func_name)) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
fld->ReplaceFunction(host_graph_func_name, host_graph_fdef));
|
||||
fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
|
||||
}
|
||||
|
||||
// Shape inference graphs might contain Placeholder nodes for outside
|
||||
@ -1999,19 +2288,19 @@ Status ExtractOutsideCompilationForFunction(
|
||||
}
|
||||
|
||||
// Replace original function.
|
||||
FunctionDef updated_fdef;
|
||||
auto updated_fdef = absl::make_unique<FunctionDef>();
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef));
|
||||
GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get()));
|
||||
const FunctionDef* original_fdef = fld->Find(func_name);
|
||||
if (original_fdef) {
|
||||
for (const auto& attr : original_fdef->attr()) {
|
||||
(*updated_fdef.mutable_attr())[attr.first] = attr.second;
|
||||
(*updated_fdef->mutable_attr())[attr.first] = attr.second;
|
||||
}
|
||||
}
|
||||
if (fld->Find(new_func_name)) {
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef));
|
||||
TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
|
||||
}
|
||||
if (VLOG_IS_ON(4)) {
|
||||
DumpGraphToFile(
|
||||
|
@ -105,6 +105,8 @@ void AllocateAndParseFlags() {
|
||||
build_ops_flags = new BuildXlaOpsPassFlags;
|
||||
build_ops_flags->tf_xla_enable_lazy_compilation = true;
|
||||
build_ops_flags->tf_xla_print_cluster_outputs = false;
|
||||
build_ops_flags->tf_xla_check_cluster_input_numerics = false;
|
||||
build_ops_flags->tf_xla_check_cluster_output_numerics = false;
|
||||
build_ops_flags->tf_xla_disable_constant_folding = false;
|
||||
|
||||
mark_for_compilation_flags = new MarkForCompilationPassFlags;
|
||||
@ -144,6 +146,14 @@ void AllocateAndParseFlags() {
|
||||
&build_ops_flags->tf_xla_print_cluster_outputs,
|
||||
"If true then insert Print nodes to print out values produced by "
|
||||
"XLA clusters."),
|
||||
Flag("tf_xla_check_cluster_input_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_input_numerics,
|
||||
"If true then insert CheckNumerics nodes to to check all cluster "
|
||||
"inputs."),
|
||||
Flag("tf_xla_check_cluster_output_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_output_numerics,
|
||||
"If true then insert CheckNumerics nodes to to check all cluster "
|
||||
"outputs."),
|
||||
|
||||
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
|
@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags {
|
||||
// clusters. Useful for debugging.
|
||||
bool tf_xla_print_cluster_outputs;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed input to
|
||||
// an XLA cluster.
|
||||
bool tf_xla_check_cluster_input_numerics;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed output
|
||||
// from an XLA cluster.
|
||||
bool tf_xla_check_cluster_output_numerics;
|
||||
|
||||
// Disables all constant folding. The primary use for this is for testing to
|
||||
// guarantee that tests are run on XLA and not on TF's CPU implementation.
|
||||
bool tf_xla_disable_constant_folding;
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
|
||||
#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
|
||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
||||
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
|
||||
@ -50,6 +51,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5,
|
||||
CloneConstantsForBetterClusteringPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 9,
|
||||
ClusterScopingPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
||||
MarkForCompilationPass);
|
||||
|
||||
|
@ -5,33 +5,48 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
XLA_OPS_DEPS = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:xla_activity_listener",
|
||||
"//tensorflow/compiler/jit:xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/jit:xla_compilation_cache",
|
||||
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/jit:xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
]
|
||||
|
||||
# Linked by tensorflow core, without registration of jit compilation passes.
|
||||
cc_library(
|
||||
name = "xla_ops",
|
||||
name = "xla_ops_no_jit_rewrite_registration",
|
||||
srcs = ["xla_ops.cc"],
|
||||
hdrs = ["xla_ops.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:xla_compilation_cache",
|
||||
"//tensorflow/compiler/jit:xla_device",
|
||||
"//tensorflow/compiler/jit:xla_launch_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:state_ops_op_lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
deps = XLA_OPS_DEPS,
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_ops",
|
||||
hdrs = ["xla_ops.h"],
|
||||
deps = XLA_OPS_DEPS + [
|
||||
":xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/jit:jit_compilation_passes",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -62,8 +63,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
platform_id = se::host::kHostPlatformId;
|
||||
@ -83,23 +83,13 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
|
||||
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
|
||||
// allocator to allocate real buffers.
|
||||
|
||||
platform_id = xla_device_metadata->platform()->id();
|
||||
device_allocator =
|
||||
custom_allocator =
|
||||
xla_device_metadata->client()->backend().memory_allocator();
|
||||
}
|
||||
|
||||
if (!device_allocator) {
|
||||
xla::StatusOr<se::Platform*> maybe_platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||
|
||||
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
|
||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||
}
|
||||
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
std::move(xla_allocator), device_allocator);
|
||||
custom_allocator);
|
||||
}
|
||||
|
||||
// A closure describing how to run a compiled version of a TensorFlow function.
|
||||
@ -184,6 +174,31 @@ class XlaExecutableClosureStore {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||
};
|
||||
|
||||
// Return allocator from platform info if non-null, or populate and return a
|
||||
// pointer to the allocator adapter with allocator from context.
|
||||
//
|
||||
// This is necessary because for XLA devices the underlying TF allocator returns
|
||||
// dummy tensors.
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
|
||||
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
|
||||
if (platform_info.custom_allocator()) {
|
||||
return platform_info.custom_allocator();
|
||||
}
|
||||
if (!ctx->op_device_context()) {
|
||||
// Stream is not set for the host platform.
|
||||
se::Platform* platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
|
||||
.ValueOrDie();
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
// platform_info.
|
||||
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
|
||||
ctx->op_device_context()->stream());
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
@ -280,6 +295,7 @@ static Status CompileToLocalExecutable(
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
|
||||
*client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options;
|
||||
options.client = *client;
|
||||
if (ctx->op_device_context() != nullptr) {
|
||||
@ -291,7 +307,8 @@ static Status CompileToLocalExecutable(
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.allow_cpu_custom_calls =
|
||||
(platform_info.platform_id() == se::host::kHostPlatformId);
|
||||
options.device_allocator = platform_info.allocator();
|
||||
options.device_allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
options.shape_representation_fn =
|
||||
platform_info.xla_device_metadata()->shape_representation_fn();
|
||||
@ -349,8 +366,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, platform_info_.allocator(),
|
||||
client, allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
platform_info_.UseMultipleStreams());
|
||||
launch_context.PopulateInputs(ctx, kernel, variables,
|
||||
@ -360,21 +380,28 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
VLOG(2) << "Executing computation.";
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(platform_info_.allocator());
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
auto run_result = executable->Run(launch_context.arguments(), run_options);
|
||||
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
|
||||
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
|
||||
run_result = executable->Run(launch_context.arguments(), run_options);
|
||||
} else {
|
||||
run_result = executable->RunAsync(launch_context.arguments(), run_options);
|
||||
}
|
||||
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
|
||||
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time: " << elapsed << "us";
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
|
||||
ctx, kernel, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias));
|
||||
VLOG(1) << "Done";
|
||||
}
|
||||
|
||||
@ -467,6 +494,10 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
if (status.code() == error::UNIMPLEMENTED) {
|
||||
LOG(WARNING) << "Compilation failed:" << status.ToString()
|
||||
<< ". Falling back to TF function call.";
|
||||
|
||||
BroadcastOptimizationRemark(
|
||||
XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
|
||||
.IgnoreError();
|
||||
executable = nullptr;
|
||||
mutex_lock guard(cannot_compile_cluster_mu_);
|
||||
cannot_compile_cluster_ = true;
|
||||
@ -498,7 +529,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
|
||||
client, executable, kernel, std::move(variables), constants_.size()));
|
||||
|
||||
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
|
||||
compilation_key.flat<string>()(0) = key;
|
||||
compilation_key.flat<tstring>()(0) = key;
|
||||
|
||||
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
|
||||
compilation_successful.flat<bool>()(0) = true;
|
||||
@ -513,13 +544,16 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
|
||||
void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(3) << "XlaRunOp " << def().name();
|
||||
Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
|
||||
const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
|
||||
const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);
|
||||
|
||||
XlaExecutableClosure closure =
|
||||
XlaExecutableClosureStore::Global()->Consume(key);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
closure.client(), platform_info_.allocator(),
|
||||
closure.client(), allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
|
||||
|
||||
@ -544,19 +578,28 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(platform_info_.allocator());
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
Env* env = Env::Default();
|
||||
auto start_time = env->NowMicros();
|
||||
|
||||
auto run_result =
|
||||
closure.executable()->Run(launch_context.arguments(), run_options);
|
||||
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
|
||||
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
|
||||
run_result =
|
||||
closure.executable()->Run(launch_context.arguments(), run_options);
|
||||
} else {
|
||||
run_result =
|
||||
closure.executable()->RunAsync(launch_context.arguments(), run_options);
|
||||
}
|
||||
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
|
||||
|
||||
auto elapsed = env->NowMicros() - start_time;
|
||||
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
closure.executable()->executable()->module().input_output_alias_config();
|
||||
|
||||
tensorflow::profiler::TraceMe hlo_module_activity(
|
||||
[&] {
|
||||
return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
|
||||
@ -567,7 +610,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
ctx,
|
||||
launch_context.PopulateOutputs(
|
||||
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args()));
|
||||
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
|
||||
input_output_alias));
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
|
||||
|
@ -37,18 +37,14 @@ class XlaPlatformInfo {
|
||||
public:
|
||||
XlaPlatformInfo() : device_type_("") {}
|
||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||
explicit XlaPlatformInfo(
|
||||
const DeviceType device_type, se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
||||
se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
: device_type_(device_type),
|
||||
platform_id_(platform_id),
|
||||
xla_device_metadata_(xla_device_metadata),
|
||||
xla_allocator_(std::move(xla_allocator)),
|
||||
device_allocator_(device_allocator) {
|
||||
CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
|
||||
}
|
||||
device_allocator_(device_allocator) {}
|
||||
|
||||
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
|
||||
|
||||
@ -56,9 +52,11 @@ class XlaPlatformInfo {
|
||||
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
|
||||
}
|
||||
|
||||
se::DeviceMemoryAllocator* allocator() const {
|
||||
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
|
||||
// Non-null only when run on an XLA device.
|
||||
se::DeviceMemoryAllocator* custom_allocator() const {
|
||||
return device_allocator_;
|
||||
}
|
||||
|
||||
DeviceType device_type() const { return device_type_; }
|
||||
|
||||
// This is equal to xla_device_metadata()->platform()->id() if
|
||||
@ -82,11 +80,8 @@ class XlaPlatformInfo {
|
||||
const XlaDevice::Metadata* xla_device_metadata_;
|
||||
|
||||
// If the op associated with this XlaPlatformInfo is placed on an XLA device
|
||||
// then device_allocator_ is the xla::Backend's memory allocator and
|
||||
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
|
||||
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
||||
// se::TfAllocatorAdapter instance.
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
|
||||
// then device_allocator_ is the xla::Backend's memory allocator. If the op
|
||||
// is placed on a regular CPU or GPU device then device_allocator_ is null.
|
||||
se::DeviceMemoryAllocator* device_allocator_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
||||
|
@ -677,8 +677,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
|
||||
}
|
||||
|
||||
DataType dtype;
|
||||
if (!GetNodeAttr(n->def(), "dtype", &dtype).ok() ||
|
||||
!DataTypeIsInteger(dtype)) {
|
||||
if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -695,7 +694,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation(
|
||||
}
|
||||
|
||||
const TensorProto* proto = nullptr;
|
||||
if (!GetNodeAttr(const_input->def(), "value", &proto).ok()) {
|
||||
if (!TryGetNodeAttr(const_input->def(), "value", &proto)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -924,20 +923,35 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency(
|
||||
}
|
||||
|
||||
absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
|
||||
// Look for an _XlaScope on both nodes. If both nodes have a scope and the
|
||||
// scopes do not match, do not cluster along this edge. This restriction is
|
||||
// overridden if the global_jit_level_ is ON. If even one of the nodes lacks
|
||||
// an _XlaScope attribute, then it is treated as a "bridge" and a cluster may
|
||||
// be created along it. We may want to restrict this behavior to require all
|
||||
// nodes marked with _XlaCompile=true to also have a _XlaScope property set
|
||||
// (and raise an error otherwise); but for now we don't do this.
|
||||
if (global_jit_level_ != OptimizerOptions::OFF) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
// Look for either _XlaScope or _XlaInternalScope on both nodes to guide
|
||||
// clustering. If both nodes have a scope and the scopes do not match, do
|
||||
// not cluster along this edge. If even one of the nodes lacks a scope
|
||||
// attribute, then it is treated as a "bridge" and a cluster may be created
|
||||
// along it.
|
||||
//
|
||||
// The difference between _XlaScope and _XlaInternalScope is that _XlaScope is
|
||||
// provided by users through jit_scope APIs, while _XlaInternalScope is
|
||||
// automatically generated by the ClusterScopingPass when auto_jit is on. As
|
||||
// such, we respect _XlaScope only when auto_jit is off, while respecting
|
||||
// _XlaInternalScope only when auto_jit is on.
|
||||
//
|
||||
// We may want to restrict the _XlaScope behavior to require all nodes marked
|
||||
// with _XlaCompile=true to also have a _XlaScope property set (and raise an
|
||||
// error otherwise); but for now we don't do this.
|
||||
|
||||
string scope;
|
||||
if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) {
|
||||
return scope;
|
||||
if (global_jit_level_ != OptimizerOptions::OFF) {
|
||||
// If global_jit_level_ is ON, respect only _XlaInternalScope.
|
||||
const string& scope =
|
||||
GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr);
|
||||
if (!scope.empty()) {
|
||||
return scope;
|
||||
}
|
||||
} else {
|
||||
// If global_jit_level_ is OFF, respect only _XlaScope.
|
||||
const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr);
|
||||
if (!scope.empty()) {
|
||||
return scope;
|
||||
}
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
@ -970,8 +984,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
int effective_cluster_size =
|
||||
(node->IsIdentity() || node->IsConstant()) ? 0 : 1;
|
||||
|
||||
bool has_functional_control_flow =
|
||||
node->type_string() == "While" || node->IsIfNode();
|
||||
bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode();
|
||||
|
||||
absl::optional<DeadnessPredicate> deadness_predicate;
|
||||
if (deadness_analysis_) {
|
||||
@ -1000,7 +1013,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
bool is_xla_compile_attr_true = false;
|
||||
|
||||
bool xla_compile_attr;
|
||||
if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) {
|
||||
if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
|
||||
is_xla_compile_attr_true |= xla_compile_attr;
|
||||
}
|
||||
|
||||
@ -1549,9 +1562,7 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally &&
|
||||
global_jit_level_ != OptimizerOptions::OFF);
|
||||
|
||||
if (!should_compile &&
|
||||
registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested &&
|
||||
if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
|
||||
device_type.type_string() == DEVICE_CPU) {
|
||||
static std::once_flag once;
|
||||
std::call_once(once, [] {
|
||||
@ -1628,10 +1639,9 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
bool IsCompilable(
|
||||
FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
|
||||
uncompilable_node_info) {
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info) {
|
||||
Device* device = flr->device();
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
|
||||
@ -1657,8 +1667,8 @@ bool IsCompilable(
|
||||
return checker.IsCompilableCall(ndef, flr);
|
||||
}
|
||||
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
uncompilable_node_result = checker.FindUncompilableNodes(ndef, flr);
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
|
||||
checker.FindUncompilableNodes(ndef, flr);
|
||||
uncompilable_node_info->swap(uncompilable_node_result);
|
||||
return uncompilable_node_info->empty();
|
||||
}
|
||||
|
@ -52,10 +52,9 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
// function is compilable iff every operator in the function body is
|
||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
|
||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
|
||||
bool IsCompilable(
|
||||
FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>*
|
||||
uncompilable_node_info = nullptr);
|
||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
|
@ -52,7 +52,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
|
||||
std::unordered_map<string, string> ids;
|
||||
for (Node* node : graph.nodes()) {
|
||||
string cluster;
|
||||
if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
|
||||
if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
|
||||
CHECK(!cluster.empty());
|
||||
ids[node->name()] = cluster;
|
||||
}
|
||||
@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
|
||||
EXPECT_EQ(0, clusters.size());
|
||||
}
|
||||
|
||||
namespace {
|
||||
Node* MakeStageNode(GraphDefBuilder& builder, string name,
|
||||
std::initializer_list<DataType> dtypes,
|
||||
absl::Span<const ops::NodeOut> values) {
|
||||
auto opts = builder.opts()
|
||||
.WithName(std::move(name))
|
||||
.WithAttr("dtypes", std::move(dtypes));
|
||||
if (opts.HaveError()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
NodeBuilder node_builder(name, "Stage", opts.op_registry());
|
||||
node_builder.Input(values);
|
||||
return opts.FinalizeBuilder(&node_builder);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||
auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
|
||||
// Construct a graph as below with two pipeline stages and test that nodes
|
||||
// in different stages will not be merged if ClusterScopingPass is on.
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// a -> add0 -> relu0 -> stage
|
||||
//
|
||||
// b
|
||||
// |
|
||||
// v
|
||||
// unstage -> add1 -> relu1
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("a")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("b")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* unstage = ops::SourceOp(
|
||||
"Unstage",
|
||||
builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
|
||||
|
||||
Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
|
||||
Node* add1 =
|
||||
ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
|
||||
Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
|
||||
ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
|
||||
MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
|
||||
|
||||
return GraphDefBuilderToGraph(builder, graph->get());
|
||||
};
|
||||
|
||||
// All nodes go into the same cluster if ClusterScopingPass is off.
|
||||
{
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_EQ(clusters["add0"], clusters["add1"]);
|
||||
EXPECT_EQ(clusters["add0"], clusters["relu1"]);
|
||||
EXPECT_EQ(clusters["relu0"], clusters["add1"]);
|
||||
EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
|
||||
// By default, ClusterScopingPass is on and different pipeline stages should
|
||||
// not be merged.
|
||||
{
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(build_staged_graph(&graph));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||
|
||||
std::unordered_map<string, string> clusters = GetClusters(*graph);
|
||||
EXPECT_NE(clusters["add0"], clusters["add1"]);
|
||||
EXPECT_NE(clusters["add0"], clusters["relu1"]);
|
||||
EXPECT_NE(clusters["relu0"], clusters["add1"]);
|
||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
@ -48,8 +50,14 @@ namespace tensorflow {
|
||||
opt_options.graph = graph;
|
||||
opt_options.session_options = &session_options;
|
||||
opt_options.flib_def = flib_def;
|
||||
MarkForCompilationPass pass;
|
||||
return pass.RunForTest(
|
||||
|
||||
if (options.enable_cluster_scoping) {
|
||||
ClusterScopingPass cluster_scoping_pass;
|
||||
TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
|
||||
}
|
||||
|
||||
MarkForCompilationPass mark_for_compilation_pass;
|
||||
return mark_for_compilation_pass.RunForTest(
|
||||
opt_options,
|
||||
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
|
||||
}
|
||||
|
@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper {
|
||||
struct Options {
|
||||
bool enable_global_jit;
|
||||
bool disable_deadness_analysis;
|
||||
bool enable_cluster_scoping;
|
||||
|
||||
Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
|
||||
Options()
|
||||
: enable_global_jit(true),
|
||||
disable_deadness_analysis(true),
|
||||
enable_cluster_scoping(true) {}
|
||||
|
||||
Options WithNoGlobalJit() {
|
||||
Options copy = *this;
|
||||
@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper {
|
||||
copy.disable_deadness_analysis = false;
|
||||
return copy;
|
||||
}
|
||||
|
||||
Options WithNoClusterScoping() {
|
||||
Options copy = *this;
|
||||
copy.enable_cluster_scoping = false;
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
|
||||
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
|
||||
|
@ -135,7 +135,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
|
||||
|
||||
if (constant_value) {
|
||||
const TensorProto* proto = nullptr;
|
||||
if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
|
||||
if (!TryGetNodeAttr(node->def(), "value", &proto)) {
|
||||
if (listener->IsInterested()) {
|
||||
*listener << "\ncould not find \"value\" attribute in node";
|
||||
}
|
||||
|
@ -45,8 +45,8 @@ class AutoClusteringTestImpl : public AutoClusteringTest {
|
||||
TEST_F(AutoClusteringTestImpl, KerasImagenetMain) {
|
||||
// Generated from
|
||||
//
|
||||
// bazel run -c opt --config=cuda \
|
||||
// tensorflow_models/official/resnet/keras:keras_imagenet_main \
|
||||
// TARGET_PATH=tensorflow_models/official/vision/image_classification \
|
||||
// bazel run -c opt --config=cuda ${TARGET_PATH}:resnet_imagenet_main \
|
||||
// -- --skip_eval --num_gpus=1 --dtype=fp16 --batch_size=192 \
|
||||
// --train_steps=210 --enable_xla --enable_eager=true
|
||||
//
|
||||
@ -57,8 +57,8 @@ TEST_F(AutoClusteringTestImpl, KerasImagenetMain) {
|
||||
TEST_F(AutoClusteringTestImpl, KerasImagenetMainGraphMode) {
|
||||
// Generated from
|
||||
//
|
||||
// bazel run -c opt --config=cuda \
|
||||
// tensorflow_models/official/resnet/keras:keras_imagenet_main \
|
||||
// TARGET_PATH=tensorflow_models/official/vision/image_classification \
|
||||
// bazel run -c opt --config=cuda ${TARGET_PATH}:resnet_imagenet_main \
|
||||
// -- --use_synthetic_data --num_gpus=1 --batch_size=117 --train_steps=600 \
|
||||
// --skip_eval=True --logtostderr --enable_xla
|
||||
TF_ASSERT_OK(
|
||||
|
@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt(
|
||||
/*input_buffer_bytes=*/k_buffer_size,
|
||||
/*output_buffer_bytes=*/k_buffer_size,
|
||||
io::ZlibCompressionOptions::GZIP());
|
||||
string decompressed_pbtxt_string;
|
||||
tstring decompressed_pbtxt_string;
|
||||
Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string);
|
||||
if (!s.ok() && !errors::IsOutOfRange(s)) {
|
||||
// OutOfRange is fine since we set the number of read bytes to INT_MAX.
|
||||
|
@ -94,3 +94,27 @@ message XlaJitCompilationActivity {
|
||||
// Total microseconds spent in (re-)compiling this cluster so far.
|
||||
int64 cumulative_compile_time_us = 4;
|
||||
}
|
||||
|
||||
// LINT.IfChange
|
||||
//
|
||||
// Used for logging situations seen in Tensorflow models being optimized that
|
||||
// are known to not perform well with XLA.
|
||||
//
|
||||
// Next ID: 3
|
||||
message XlaOptimizationRemark {
|
||||
// Next ID: 6
|
||||
enum Warning {
|
||||
NONE = 0;
|
||||
INACCURATE_OPERATION = 1;
|
||||
SLOW_OPERATION = 2;
|
||||
UNIMPLEMENTED_OPERATION = 3;
|
||||
SLOW_IMAGE_RESIZE_DIMENSIONS = 4;
|
||||
MEGAMORPHIC_FUNCTION = 5;
|
||||
}
|
||||
|
||||
Warning warning = 1;
|
||||
|
||||
// Information such as which node was the problem.
|
||||
string debug_information = 2;
|
||||
}
|
||||
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/compiler/jit/xla_activity_listener.h)
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -71,6 +72,21 @@ Status BroadcastXlaActivity(
|
||||
});
|
||||
}
|
||||
|
||||
Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark) {
|
||||
VLOG(2) << "OptimizationRemark: " << optimization_remark.DebugString();
|
||||
return ForEachListener([&](XlaActivityListener* listener) {
|
||||
return listener->Listen(optimization_remark);
|
||||
});
|
||||
}
|
||||
|
||||
Status BroadcastOptimizationRemark(
|
||||
XlaOptimizationRemark::Warning optimization_warning,
|
||||
string debug_information) {
|
||||
XlaOptimizationRemark remark;
|
||||
remark.set_warning(optimization_warning);
|
||||
remark.set_debug_information(std::move(debug_information));
|
||||
return BroadcastOptimizationRemark(std::move(remark));
|
||||
}
|
||||
void RegisterXlaActivityListener(
|
||||
std::unique_ptr<XlaActivityListener> listener) {
|
||||
XlaActivityListenerList* listener_list = GetXlaActivityListenerList();
|
||||
|
@ -27,6 +27,18 @@ Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity);
|
||||
// Broadcast `jit_compilation_activity` to all the registered listeners.
|
||||
Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity);
|
||||
|
||||
// Broadcast `jit_compilation_activity` to all the registered listeners.
|
||||
Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark);
|
||||
|
||||
// LINT.IfChange
|
||||
// Called after TensorFlow realizes possible lost performance. The parameters in
|
||||
// this should match all of the values in the XlaOptimizationRemark proto.
|
||||
Status BroadcastOptimizationRemark(
|
||||
XlaOptimizationRemark::Warning optimization_warning,
|
||||
string debug_information);
|
||||
|
||||
// LINT.ThenChange(//tensorflow/compiler/jit/xla_activity.proto)
|
||||
|
||||
// Various components of the system can subclass XlaActivityListener to
|
||||
// notifications on auto-clustering and JIT compilation events.
|
||||
//
|
||||
@ -41,6 +53,9 @@ class XlaActivityListener {
|
||||
virtual Status Listen(
|
||||
const XlaJitCompilationActivity& jit_compilation_activity) = 0;
|
||||
|
||||
// Called after TensorFlow realizes possible lost performance.
|
||||
virtual Status Listen(const XlaOptimizationRemark& optimization_remark) = 0;
|
||||
|
||||
// Called at program exit in best-effort manner to give listeners a chance to
|
||||
// flush their state.
|
||||
//
|
||||
|
@ -43,6 +43,10 @@ class TestListener : public XlaActivityListener {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Listen(const XlaOptimizationRemark& optimization_remark) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
~TestListener() override {}
|
||||
|
||||
const XlaAutoClusteringActivity& auto_clustering_activity() const {
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
|
||||
@ -59,6 +60,23 @@ class XlaActivityLoggingListener final : public XlaActivityListener {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Listen(const XlaOptimizationRemark& optimization_remark) override {
|
||||
if (!IsEnabled()) {
|
||||
VLOG(3) << "Logging XlaJitCompilationActivity disabled";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (Logger* logger = Logger::GetSingletonAsync()) {
|
||||
VLOG(2) << "Logging XlaJitCompilationActivity";
|
||||
VLOG(3) << optimization_remark.DebugString();
|
||||
logger->LogProto(optimization_remark);
|
||||
} else {
|
||||
VLOG(2) << "Not logging: logger not ready yet.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsEnabled() {
|
||||
static bool result = ComputeIsEnabled();
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
@ -27,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_optimizer.h"
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
@ -224,6 +227,20 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
out_compilation_result, out_executable);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Print something that users can search for to definitively ascertain that XLA
|
||||
// was used for their TF model.
|
||||
//
|
||||
// Prints only once to avoid spamming LOG(INFO).
|
||||
void LogOnceXlaCompiledFirstCluster() {
|
||||
static absl::once_flag log_once;
|
||||
absl::call_once(log_once, [] {
|
||||
LOG(INFO) << "Compiled cluster using XLA! This line is logged at most "
|
||||
"once for the lifetime of the process.";
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status XlaCompilationCache::CompileImpl(
|
||||
const XlaCompiler::Options& options, const NameAttrList& function,
|
||||
absl::Span<const XlaCompiler::Argument> args,
|
||||
@ -301,6 +318,9 @@ Status XlaCompilationCache::CompileImpl(
|
||||
}
|
||||
|
||||
if (is_megamorphic) {
|
||||
BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
|
||||
function.name())
|
||||
.IgnoreError();
|
||||
VLOG(3) << "Not compiling cluster " << function.name()
|
||||
<< " because it is megamorphic.";
|
||||
return false;
|
||||
@ -346,11 +366,13 @@ Status XlaCompilationCache::CompileImpl(
|
||||
|
||||
const uint64 compile_end_us = env->NowMicros();
|
||||
const uint64 compile_time_us = compile_end_us - compile_start_us;
|
||||
metrics::UpdateXlaCompilationTime(compile_time_us);
|
||||
{
|
||||
mutex_lock lock(cluster_compile_stats_mu_);
|
||||
auto it = cluster_compile_stats_.find(function.name());
|
||||
it->second.compile_count++;
|
||||
it->second.cumulative_compile_time_us += compile_time_us;
|
||||
LogOnceXlaCompiledFirstCluster();
|
||||
VLOG(1) << "compiled " << function.name() << " "
|
||||
<< it->second.compile_count
|
||||
<< " times, compile time: " << compile_time_us
|
||||
|
@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
|
||||
executable->Run(launch_context.arguments(), run_options);
|
||||
TF_RETURN_IF_ERROR(run_result.status());
|
||||
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias =
|
||||
executable->executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
|
||||
ctx, result, run_result.ConsumeValueOrDie(),
|
||||
/*missing_ctx_input_prefix=*/0));
|
||||
/*missing_ctx_input_prefix=*/0, input_output_alias));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 14> kAllXlaCpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
|
||||
DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 16> kAllXlaCpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
|
||||
|
@ -203,6 +203,8 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
|
||||
device_ordinal_(options.device_ordinal),
|
||||
jit_device_name_(options.compilation_device_name),
|
||||
platform_(options.platform),
|
||||
intra_op_parallelism_threads_(
|
||||
session_options.config.intra_op_parallelism_threads()),
|
||||
use_multiple_streams_(options.use_multiple_streams),
|
||||
shape_representation_fn_(options.shape_representation_fn),
|
||||
allowed_devices_(options.allowed_devices) {
|
||||
@ -233,10 +235,13 @@ xla::LocalClient* XlaDevice::client() const {
|
||||
// don't want to do it until we get a chance to hook the platform up
|
||||
// to a simulator.
|
||||
|
||||
xla::LocalClientOptions options;
|
||||
options.set_platform(platform_)
|
||||
.set_allowed_devices(allowed_devices_)
|
||||
.set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
|
||||
// TODO(b/78468222): This can fail, at least when the backend is GPU and
|
||||
// there is no GPU on the host.
|
||||
return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_)
|
||||
.ValueOrDie();
|
||||
return xla::ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie();
|
||||
}
|
||||
|
||||
Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
|
||||
|
@ -202,6 +202,8 @@ class XlaDevice : public LocalDevice {
|
||||
const DeviceType jit_device_name_;
|
||||
// The platform for this device.
|
||||
se::Platform* const platform_; // Not owned.
|
||||
// Intra-op threads to spawn (from SessionOptions).
|
||||
const int intra_op_parallelism_threads_;
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
|
||||
|
||||
|
@ -90,8 +90,9 @@ XlaDeviceContext::XlaDeviceContext(
|
||||
CHECK(host_to_device_stream_ != nullptr);
|
||||
CHECK(stream_ != nullptr);
|
||||
if (!shape_representation_fn_) {
|
||||
shape_representation_fn_ = [](const TensorShape& shape,
|
||||
DataType dtype) -> xla::StatusOr<xla::Shape> {
|
||||
shape_representation_fn_ =
|
||||
[](const TensorShape& shape, DataType dtype,
|
||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
||||
return xla_shape;
|
||||
@ -130,9 +131,10 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
CHECK(xla_tensor);
|
||||
|
||||
Status status = [&]() -> Status {
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape shape,
|
||||
shape_representation_fn_(device_tensor->shape(),
|
||||
device_tensor->dtype()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
xla::Shape shape,
|
||||
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(),
|
||||
/*use_fast_memory=*/false));
|
||||
|
||||
// The device tensor should always be fresh.
|
||||
TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
|
||||
|
@ -212,11 +212,11 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \
|
||||
.Device(DEVICE) \
|
||||
.HostMemory("output") \
|
||||
.TypeConstraint<string>("T"), \
|
||||
.TypeConstraint<tstring>("T"), \
|
||||
ArgOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \
|
||||
.Device(DEVICE) \
|
||||
.TypeConstraint<string>("T") \
|
||||
.TypeConstraint<tstring>("T") \
|
||||
.HostMemory("input"), \
|
||||
RetvalOp);
|
||||
|
||||
|
@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 14> kAllXlaGpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
|
||||
DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
|
||||
|
@ -14,243 +14,20 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
// not revisited in future calls to ScanForValue, so callers must take
|
||||
// care to order their calls.
|
||||
//
|
||||
// Useful for merging multiple sorted lists in O(n) time.
|
||||
class SinglePassSearch {
|
||||
public:
|
||||
// Creates a SinglePassSearch object that can be used to search in `values`.
|
||||
// Does not take ownership of `values`. `values` must outlive this.
|
||||
// `values` must be sorted.
|
||||
explicit SinglePassSearch(const std::vector<int>* values)
|
||||
: current_index_(0), values_(values) {}
|
||||
|
||||
// Scans forward in the vector looking for "value", updating the internal
|
||||
// position in to the vector.
|
||||
// Returns true iff the vector contains the given value at or after current
|
||||
// position.
|
||||
// Not thread-safe.
|
||||
bool ScanForValue(int value) {
|
||||
while (current_index_ < values_->size() &&
|
||||
(*values_)[current_index_] <= value) {
|
||||
if ((*values_)[current_index_] == value) {
|
||||
current_index_++;
|
||||
return true;
|
||||
}
|
||||
current_index_++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int current_index_;
|
||||
const std::vector<int>* values_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) const {
|
||||
const FunctionDef* function_def =
|
||||
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
|
||||
if (function_def == nullptr) {
|
||||
// The node def is not calling a function. Individual ops can be
|
||||
// run directly using on-demand mode, no need to create XlaLaunch
|
||||
// kernel for them.
|
||||
return false;
|
||||
}
|
||||
|
||||
// If kXlaCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaCompileAttr);
|
||||
if (it != node_def.attr().end()) {
|
||||
return it->second.b();
|
||||
}
|
||||
|
||||
// kXlaCompileAttr is not set on node_def, check if it is set on
|
||||
// FunctionDef.
|
||||
bool xla_compile = false;
|
||||
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
|
||||
node_def, kXlaCompileAttr, &xla_compile);
|
||||
if (!status.ok() || !xla_compile) {
|
||||
if (VLOG_IS_ON(3)) {
|
||||
if (!status.ok()) {
|
||||
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
|
||||
<< node_def.op() << ". status=" << status.ToString();
|
||||
} else {
|
||||
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
// runtime, returns this function's body in `fbody` as well as the indices
|
||||
// of its constant and resource arguments.
|
||||
// `fbody` is owned by `flr`.
|
||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||
// They are sorted in ascending order on this function's return.
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||
std::vector<bool> const_args(arg_types.size());
|
||||
// If we can't analyze the const args. Bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr, flr));
|
||||
|
||||
for (int i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
constant_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// There can be hundreds of resource variables. Reserve the space for them.
|
||||
// We don't reserve for constants above as they are usually few.
|
||||
resource_arg_indices->reserve(arg_types.size());
|
||||
for (int i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] == DT_RESOURCE) {
|
||||
resource_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
return CanCreateXlaKernel(flr, node_def);
|
||||
}
|
||||
|
||||
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel) const {
|
||||
if (!CanCreateKernel(*flr, node_def)) {
|
||||
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
|
||||
}
|
||||
|
||||
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
|
||||
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
uncompilable_node_info;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_node_info)) {
|
||||
string message = absl::StrCat(
|
||||
"Function invoked by the following node is not compilable: ",
|
||||
node_def.ShortDebugString(), ".\n");
|
||||
absl::StrAppend(&message, "Uncompilable nodes:\n");
|
||||
for (const auto& node_info : uncompilable_node_info) {
|
||||
string node_message =
|
||||
absl::StrCat("\t", node_info.name, ": ",
|
||||
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
|
||||
for (const auto& stack_frame : node_info.stack_trace) {
|
||||
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
|
||||
stack_frame.name, stack_frame.function_name);
|
||||
}
|
||||
absl::StrAppend(&message, node_message);
|
||||
}
|
||||
VLOG(1) << message;
|
||||
// node_def is calling a function that XLA can't compile.
|
||||
return errors::InvalidArgument(message);
|
||||
}
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Set input and output memory types.
|
||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||
// These indices are used only for optimization purposes. They allow us
|
||||
// to loop over constant_arg_indices and resource_arg_indices only once
|
||||
// while iterating over all the function arguments checking if it is a
|
||||
// resource or a constant.
|
||||
// The reason we optimized this code is because functions can have a lot of
|
||||
// captured arguments. For example, the backward pass of ResNet50 takes in all
|
||||
// 214 variables and a similar number of activations.
|
||||
SinglePassSearch constants_search(&constant_arg_indices);
|
||||
SinglePassSearch resources_search(&resource_arg_indices);
|
||||
for (int i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
||||
// Compile-time constants and resource handles are expected to be in
|
||||
// host memory.
|
||||
input_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
// One might wonder, about the case where a compile-time constant argument
|
||||
// (which must be in host memory) is also used as an input into an op,
|
||||
// e.g. Add, that expects its inputs in device memory. Here is how it
|
||||
// works now.
|
||||
// First, what do we mean by "op expects an input in XYZ memory"?
|
||||
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
||||
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
||||
// numeric value of the compile-time constant tensors, so it really expects
|
||||
// them to be on in host memory. However, for other inputs, it refers to them
|
||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||
// instatiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||
// the actual executable, which is copied to the device before being
|
||||
// executed. Thus, when this executable runs, the constant is available in
|
||||
// device memory.
|
||||
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory except for resources.
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
for (int i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
||||
output_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
|
||||
// Create the kernel.
|
||||
NameAttrList function;
|
||||
function.set_name(node_def.op());
|
||||
*(function.mutable_attr()) = node_def.attr();
|
||||
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
|
||||
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
|
||||
|
||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||
&construction, constant_arg_indices, resource_arg_indices, function);
|
||||
return s;
|
||||
return CreateXlaKernel(flr, node_def, kernel);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -39,4 +39,4 @@ class XlaKernelCreator : public CustomKernelCreator {
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_
|
||||
|
259
tensorflow/compiler/jit/xla_kernel_creator_util.cc
Normal file
259
tensorflow/compiler/jit/xla_kernel_creator_util.cc
Normal file
@ -0,0 +1,259 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
// not revisited in future calls to ScanForValue, so callers must take
|
||||
// care to order their calls.
|
||||
//
|
||||
// Useful for merging multiple sorted lists in O(n) time.
|
||||
class SinglePassSearch {
|
||||
public:
|
||||
// Creates a SinglePassSearch object that can be used to search in `values`.
|
||||
// Does not take ownership of `values`. `values` must outlive this.
|
||||
// `values` must be sorted.
|
||||
explicit SinglePassSearch(const std::vector<int>* values)
|
||||
: current_index_(0), values_(values) {}
|
||||
|
||||
// Scans forward in the vector looking for "value", updating the internal
|
||||
// position in to the vector.
|
||||
// Returns true iff the vector contains the given value at or after current
|
||||
// position.
|
||||
// Not thread-safe.
|
||||
bool ScanForValue(int value) {
|
||||
while (current_index_ < values_->size() &&
|
||||
(*values_)[current_index_] <= value) {
|
||||
if ((*values_)[current_index_] == value) {
|
||||
current_index_++;
|
||||
return true;
|
||||
}
|
||||
current_index_++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int current_index_;
|
||||
const std::vector<int>* values_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def) {
|
||||
const FunctionDef* function_def =
|
||||
flr.GetFunctionLibraryDefinition()->Find(node_def.name());
|
||||
if (function_def == nullptr) {
|
||||
// The node def is not calling a function. Individual ops can be
|
||||
// run directly using on-demand mode, no need to create XlaLaunch
|
||||
// kernel for them.
|
||||
return false;
|
||||
}
|
||||
|
||||
// If kXlaCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaCompileAttr);
|
||||
if (it != node_def.attr().end()) {
|
||||
return it->second.b();
|
||||
}
|
||||
|
||||
// kXlaCompileAttr is not set on node_def, check if it is set on
|
||||
// FunctionDef.
|
||||
bool xla_compile = false;
|
||||
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
|
||||
node_def, kXlaCompileAttr, &xla_compile);
|
||||
if (!status.ok() || !xla_compile) {
|
||||
if (VLOG_IS_ON(3)) {
|
||||
if (!status.ok()) {
|
||||
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
|
||||
<< node_def.op() << ". status=" << status.ToString();
|
||||
} else {
|
||||
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
// runtime, returns this function's body in `fbody` as well as the indices
|
||||
// of its constant and resource arguments.
|
||||
// `fbody` is owned by `flr`.
|
||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||
// They are sorted in ascending order on this function's return.
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||
std::vector<bool> const_args(arg_types.size());
|
||||
// If we can't analyze the const args. Bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr, flr));
|
||||
|
||||
for (int i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
constant_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// There can be hundreds of resource variables. Reserve the space for them.
|
||||
// We don't reserve for constants above as they are usually few.
|
||||
resource_arg_indices->reserve(arg_types.size());
|
||||
for (int i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] == DT_RESOURCE) {
|
||||
resource_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel) {
|
||||
if (!CanCreateXlaKernel(*flr, node_def)) {
|
||||
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
|
||||
}
|
||||
|
||||
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
|
||||
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
uncompilable_node_info;
|
||||
for (const auto& it : uncompilable_nodes_map) {
|
||||
for (const auto& info : it.second.second) {
|
||||
uncompilable_node_info.emplace_back(info);
|
||||
}
|
||||
}
|
||||
string message = absl::StrCat(
|
||||
"Function invoked by the following node is not compilable: ",
|
||||
node_def.ShortDebugString(), ".\n");
|
||||
absl::StrAppend(&message, "Uncompilable nodes:\n");
|
||||
for (const auto& node_info : uncompilable_node_info) {
|
||||
string node_message =
|
||||
absl::StrCat("\t", node_info.name, ": ",
|
||||
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
|
||||
for (const auto& stack_frame : node_info.stack_trace) {
|
||||
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
|
||||
stack_frame.name, stack_frame.function_name);
|
||||
}
|
||||
absl::StrAppend(&message, node_message);
|
||||
}
|
||||
VLOG(1) << message;
|
||||
// node_def is calling a function that XLA can't compile.
|
||||
return errors::InvalidArgument(message);
|
||||
}
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Set input and output memory types.
|
||||
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
|
||||
// These indices are used only for optimization purposes. They allow us
|
||||
// to loop over constant_arg_indices and resource_arg_indices only once
|
||||
// while iterating over all the function arguments checking if it is a
|
||||
// resource or a constant.
|
||||
// The reason we optimized this code is because functions can have a lot of
|
||||
// captured arguments. For example, the backward pass of ResNet50 takes in all
|
||||
// 214 variables and a similar number of activations.
|
||||
SinglePassSearch constants_search(&constant_arg_indices);
|
||||
SinglePassSearch resources_search(&resource_arg_indices);
|
||||
for (int i = 0; i < fbody->arg_types.size(); ++i) {
|
||||
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
|
||||
// Compile-time constants and resource handles are expected to be in
|
||||
// host memory.
|
||||
input_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
// One might wonder, about the case where a compile-time constant argument
|
||||
// (which must be in host memory) is also used as an input into an op,
|
||||
// e.g. Add, that expects its inputs in device memory. Here is how it
|
||||
// works now.
|
||||
// First, what do we mean by "op expects an input in XYZ memory"?
|
||||
// There are two types of "ops" here: the tf2xla kernel and the HLO
|
||||
// computation it builds. The tf2xla kernel needs to retrieve the actual
|
||||
// numeric value of the compile-time constant tensors, so it really expects
|
||||
// them to be on in host memory. However, for other inputs, it refers to them
|
||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||
// instatiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||
// the actual executable, which is copied to the device before being
|
||||
// executed. Thus, when this executable runs, the constant is available in
|
||||
// device memory.
|
||||
|
||||
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
|
||||
// in device memory except for resources.
|
||||
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
|
||||
for (int i = 0; i < fbody->ret_types.size(); ++i) {
|
||||
if (fbody->ret_types[i] == DT_RESOURCE) {
|
||||
output_memory_types[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
|
||||
// Create the kernel.
|
||||
NameAttrList function;
|
||||
function.set_name(node_def.op());
|
||||
*(function.mutable_attr()) = node_def.attr();
|
||||
|
||||
Device* dev = flr->device();
|
||||
Status s;
|
||||
OpKernelConstruction construction(
|
||||
DeviceType(dev->device_type()), dev,
|
||||
dev->GetAllocator(AllocatorAttributes()), &node_def,
|
||||
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
|
||||
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
|
||||
|
||||
*kernel = absl::make_unique<XlaLocalLaunchBase>(
|
||||
&construction, constant_arg_indices, resource_arg_indices, function);
|
||||
return s;
|
||||
}
|
||||
} // namespace tensorflow
|
39
tensorflow/compiler/jit/xla_kernel_creator_util.h
Normal file
39
tensorflow/compiler/jit/xla_kernel_creator_util.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FunctionLibraryRuntime;
|
||||
class OpKernel;
|
||||
|
||||
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
|
||||
// true if 'node_def' is a call to a compilable function defined in 'flr',
|
||||
// with the kXlaCompileAttr set.
|
||||
bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
|
||||
const NodeDef& node_def);
|
||||
|
||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
|
@ -42,6 +42,13 @@ namespace tensorflow {
|
||||
namespace {
|
||||
using xla::ScopedShapedBuffer;
|
||||
using xla::ShapedBuffer;
|
||||
|
||||
const char kPossibleNonVariableResourceHintMessage[] =
|
||||
"If the error is similar to `Trying to access resource using the wrong "
|
||||
"type`, this is likely because XLA only accepts Resource Variables as "
|
||||
"inputs by snapshotting their values. Other TensorFlow resource types like "
|
||||
"TensorList/TensorArray/Stack are not supported. Try removing non-variable "
|
||||
"resource inputs to XLA.";
|
||||
} // anonymous namespace
|
||||
|
||||
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
|
||||
@ -88,7 +95,12 @@ static Status GetVariableInfosFromCtxInputs(
|
||||
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
|
||||
|
||||
std::vector<core::RefCountPtr<Var>> variables;
|
||||
TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));
|
||||
|
||||
Status s = LookupResources(ctx, resource_handles, &variables);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
|
||||
return s;
|
||||
}
|
||||
|
||||
result->clear();
|
||||
result->reserve(variable_indices.size());
|
||||
@ -235,9 +247,32 @@ void XlaComputationLaunchContext::PopulateInputs(
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
int output_num) {
|
||||
xla::ShapeIndex output_index;
|
||||
if (input_output_alias.shape().IsTuple()) {
|
||||
output_index = {output_num};
|
||||
} else {
|
||||
DCHECK_EQ(output_num, 0)
|
||||
<< "output_num must be 0 for non-tuple shapes but is " << output_num;
|
||||
output_index = {};
|
||||
}
|
||||
if (input_output_alias.shape().tuple_shapes_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
return input_output_alias.OutputHasAlias(output_index) &&
|
||||
input_output_alias.GetAliasedParameter(output_index).value().kind ==
|
||||
xla::HloInputOutputAliasConfig::kUserAlias;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix) {
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
@ -331,8 +366,16 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
DCHECK(output.buffer({output_num}).is_null())
|
||||
<< "Expected output buffer to be aliased, but it is not nil.";
|
||||
}
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
if (allocate_xla_tensors_) {
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
return errors::Unimplemented(
|
||||
"Aliasing is not yet supported for allocate_xla_tensors_.");
|
||||
}
|
||||
Tensor* output_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
|
||||
@ -347,8 +390,18 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
CHECK_EQ(output_tensor->TotalBytes(), 0);
|
||||
}
|
||||
} else {
|
||||
bool is_aliased = false;
|
||||
if (MustAliasOutput(input_output_alias, output_num)) {
|
||||
int xla_param = input_output_alias.GetAliasedParameter({output_num})
|
||||
.value()
|
||||
.parameter_number;
|
||||
DCHECK(arg_ptrs_[xla_param] != nullptr);
|
||||
buffer = arg_ptrs_[xla_param]->buffer({});
|
||||
is_aliased = true;
|
||||
}
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
ctx->expected_output_dtype(i), shape,
|
||||
/*unref_buffer=*/!is_aliased, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
@ -412,7 +465,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
|
||||
write.type, write.shape, buffer, allocator);
|
||||
write.type, write.shape, /*unref_buffer=*/true, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
}
|
||||
++output_num;
|
||||
|
@ -149,10 +149,10 @@ class XlaComputationLaunchContext {
|
||||
//
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
|
||||
// missing and adjusts input indices accordingly.
|
||||
Status PopulateOutputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* kernel,
|
||||
xla::ScopedShapedBuffer output,
|
||||
int missing_ctx_input_prefix);
|
||||
Status PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias);
|
||||
|
||||
// Return the argument list. Only valid after PopulateInputs() has been
|
||||
// called.
|
||||
@ -193,12 +193,15 @@ class XlaTensorBuffer : public TensorBuffer {
|
||||
}
|
||||
|
||||
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
|
||||
se::DeviceMemoryBase buffer, Allocator* allocator) {
|
||||
bool unref_buffer, se::DeviceMemoryBase buffer,
|
||||
Allocator* allocator) {
|
||||
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
|
||||
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
|
||||
buffer.size(), allocator);
|
||||
Tensor t(dtype, shape, tensor_buffer);
|
||||
tensor_buffer->Unref();
|
||||
if (unref_buffer) {
|
||||
tensor_buffer->Unref();
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
|
@ -19,10 +19,23 @@ filegroup(
|
||||
srcs = glob(["**/*.td"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op_name_mapper",
|
||||
srcs = ["op_name_mapper.cc"],
|
||||
hdrs = ["op_name_mapper.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_mlir_opt_main",
|
||||
srcs = ["tf_mlir_opt_main.cc"],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
":init_mlir",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
|
||||
@ -31,12 +44,14 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/xla",
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:AffineDialectRegistration",
|
||||
"@local_config_mlir//:MlirOptLib",
|
||||
@ -49,16 +64,29 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "init_mlir",
|
||||
srcs = ["init_mlir.cc"],
|
||||
hdrs = ["init_mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm//:support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-opt",
|
||||
deps = [
|
||||
":tf_mlir_opt_main",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-mlir-translate",
|
||||
srcs = ["tf_mlir_translate_main.cc"],
|
||||
deps = [
|
||||
":init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
@ -66,12 +94,14 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:TranslateClParser",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
|
||||
],
|
||||
)
|
||||
|
||||
|
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) {
|
||||
constexpr char kSeparator[] = "--";
|
||||
|
||||
// Find index of separator between two sets of flags.
|
||||
int pass_remainder = 1;
|
||||
bool split = false;
|
||||
for (int i = 0; i < *argc; ++i) {
|
||||
if (llvm::StringRef((*argv)[i]) == kSeparator) {
|
||||
pass_remainder = i;
|
||||
*argc -= (i + 1);
|
||||
split = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv);
|
||||
if (split) {
|
||||
*argc += pass_remainder;
|
||||
(*argv)[1] = (*argv)[0];
|
||||
++*argv;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user