commit 149fc8dbd6
Merge branch 'master' of https://github.com/tensorflow/tensorflow into tensorrt to fix some of the failed tests.
@@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow

 If you open a GitHub issue, here is our policy:

-1. It must be a bug or a feature request.
+1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
 2. The form below must be filled out.
 3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).

@@ -6,7 +6,7 @@

 | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
 |-----------------|---------------------|------------------|-------------------|---------------|
-| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |
+| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |

 **TensorFlow** is an open source software library for numerical computation using
 data flow graphs. The graph nodes represent mathematical operations, while
RELEASE.md (25 changed lines)
@@ -1,18 +1,39 @@
 # Release 1.5.0

 ## Breaking Changes
-* Prebuilt binaries are now built against CUDA 9 and cuDNN 7.
+* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7.
 * Our Linux binaries are built using ubuntu 16 containers, potentially
   introducing glibc incompatibility issues with ubuntu 14.
 * Starting from 1.6 release, our prebuilt binaries will use AVX instructions.
   This may break TF on older CPUs.

+## Known Bugs
+* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or
+  `CUDA_ILLEGAL_ADDRESS` failures.
+
+  Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9
+  and CUDA 9.1 sometimes does not properly compute the carry bit when
+  decomposing 64-bit address calculations with large offsets (e.g. `load [x +
+  large_constant]`) into 32-bit arithmetic in SASS.
+
+  As a result, these versions of `ptxas` miscompile most XLA programs which use
+  more than 4GB of temp memory. This results in garbage results and/or
+  `CUDA_ERROR_ILLEGAL_ADDRESS` failures.
+
+  A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a
+  fix for CUDA 9.0.x. Until the fix is available, the only workaround is to
+  [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x
+  or disable XLA:GPU.
+
+  TensorFlow will print a warning if you use XLA:GPU with a known-bad version of
+  CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122.
+
 ## Major Features And Improvements
 * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
   preview version is now available.
 * [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
   dev preview is now available.
-* CUDA 9 and cuDNN 7 support.
+* CUDA 9.0 and cuDNN 7 support.
 * Accelerated Linear Algebra (XLA):
   * Add `complex64` support to XLA compiler.
   * `bfloat` support is now added to XLA infrastructure.
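Regarding the XLA:GPU workaround in the Known Bugs entry above: XLA JIT compilation is opt-in in TF 1.5, so avoiding the miscompile amounts to leaving the JIT disabled. A minimal sketch, assuming the TF 1.x `ConfigProto` API:

```python
import tensorflow as tf

# Leave the global JIT level at its default (off) on CUDA 9 / 9.1 systems;
# enabling it is what routes graphs through the affected XLA:GPU path.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.OFF)  # explicit here; OFF is already the default
sess = tf.Session(config=config)
```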
@@ -298,7 +298,7 @@ def get_var(environ_cp,
       System".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
-    yes_reply: optionanl string for reply when feature is enabled.
+    yes_reply: optional string for reply when feature is enabled.
     no_reply: optional string for reply when feature is disabled.

   Returns:
@@ -411,7 +411,7 @@ def set_action_env_var(environ_cp,
       System".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
-    yes_reply: optionanl string for reply when feature is enabled.
+    yes_reply: optional string for reply when feature is enabled.
     no_reply: optional string for reply when feature is disabled.
   """
   var = int(
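The `yes_reply`/`no_reply` parameters documented above drive an interactive prompt. A hedged sketch of that pattern (illustration only, not configure.py itself; assumes Python 3's `input`):

```python
def ask_feature(question, enabled_by_default, yes_reply=None, no_reply=None):
    # Ask a yes/no question, fall back to the default on empty input, and
    # print the optional reply string for the chosen answer.
    default_hint = 'Y/n' if enabled_by_default else 'y/N'
    answer = input('%s [%s]: ' % (question, default_hint)).strip().lower()
    enabled = enabled_by_default if not answer else answer.startswith('y')
    reply = yes_reply if enabled else no_reply
    if reply:
        print(reply)
    return enabled
```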
@@ -1354,6 +1354,7 @@ def main():
   environ_cp['TF_NEED_GCP'] = '0'
   environ_cp['TF_NEED_HDFS'] = '0'
   environ_cp['TF_NEED_JEMALLOC'] = '0'
+  environ_cp['TF_NEED_KAFKA'] = '0'
   environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
   environ_cp['TF_NEED_COMPUTECPP'] = '0'
   environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1372,6 +1373,8 @@ def main():
                 'with_hdfs_support', True, 'hdfs')
   set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
                 'with_s3_support', True, 's3')
+  set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
+                'with_kafka_support', False, 'kafka')
   set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                 False, 'xla')
   set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
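A hedged usage sketch for the new flag: `set_build_var` maps an answered `TF_NEED_KAFKA` onto the `with_kafka_support` define (see the `config_setting` added in the next hunk), so pre-seeding the environment answers the prompt non-interactively:

```python
import os
import subprocess

# Hypothetical invocation: enable Kafka support without being prompted.
os.environ['TF_NEED_KAFKA'] = '1'
subprocess.check_call(['./configure'])
```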
@@ -211,6 +211,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+config_setting(
+    name = "with_kafka_support",
+    define_values = {"with_kafka_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Crosses between platforms and file system libraries not supported on those
 # platforms due to limitations in nested select() statements.
 config_setting(
@@ -544,8 +550,10 @@ filegroup(
         "//tensorflow/contrib/predictor:all_files",
         "//tensorflow/contrib/py2tf:all_files",
         "//tensorflow/contrib/py2tf/converters:all_files",
+        "//tensorflow/contrib/py2tf/impl:all_files",
         "//tensorflow/contrib/py2tf/pyct:all_files",
         "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
+        "//tensorflow/contrib/py2tf/utils:all_files",
         "//tensorflow/contrib/quantize:all_files",
         "//tensorflow/contrib/receptive_field:all_files",
         "//tensorflow/contrib/reduce_slice_ops:all_files",
tensorflow/SECURITY.md (new file, 239 lines)
@@ -0,0 +1,239 @@
+# Using TensorFlow Securely
+
+This document discusses how to safely deal with untrusted programs (models or
+model parameters), and input data. Below, we also provide guidelines on how to
+report vulnerabilities in TensorFlow.
+
+## TensorFlow models are programs
+
+TensorFlow's runtime system interprets and executes programs. What machine
+learning practitioners term
+[**models**](https://developers.google.com/machine-learning/glossary/#model) are
+expressed as programs that TensorFlow executes. TensorFlow programs are encoded
+as computation
+[**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
+The model's parameters are often stored separately in **checkpoints**.
+
+At runtime, TensorFlow executes the computation graph using the parameters
+provided. Note that the behavior of the computation graph may change
+depending on the parameters provided. TensorFlow itself is not a sandbox. When
+executing the computation graph, TensorFlow may read and write files, send and
+receive data over the network, and even spawn additional processes. All these
+tasks are performed with the permissions of the TensorFlow process. Allowing
+for this flexibility makes for a powerful machine learning platform,
+but it has implications for security.
+
+The computation graph may also accept **inputs**. Those inputs are the
+data you supply to TensorFlow to train a model, or to use a model to run
+inference on the data.
+
+**TensorFlow models are programs, and need to be treated as such from a security
+perspective.**
+
+## Running untrusted models
+
+As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
+[nsjail](https://github.com/google/nsjail)).
+
+There are several ways in which a model could become untrusted. Obviously, if an
+untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
+The same is true if the untrusted party provides Python code, such as the
+Python code that generates TensorFlow graphs.
+
+Even if the untrusted party only supplies the serialized computation
+graph (in the form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
+set of computation primitives available to TensorFlow is powerful enough that
+you should assume that the TensorFlow process effectively executes arbitrary
+code. One common solution is to whitelist only a few safe Ops. While this is
+possible in theory, we still recommend you sandbox the execution.
+
+It depends on the computation graph whether a user-provided checkpoint is safe.
+It is easily possible to create computation graphs in which malicious
+checkpoints can trigger unsafe behavior. For example, consider a graph that
+contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
+the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
+stored in the checkpoint, whoever provides the checkpoint now has the ability to
+trigger unsafe behavior, even though the graph is not under their control.
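A hedged TF 1.x sketch of the pattern just described; the variable names and the unsafe branch are hypothetical stand-ins:

```python
import numpy as np
import tensorflow as tf

# The variable's value is restored from whatever checkpoint is loaded, so the
# checkpoint provider decides which branch of the tf.cond executes.
switch = tf.Variable(False, name="switch")

def harmless():
    return tf.constant(np.int32(0))

def unsafe():
    # Stand-in for any side-effecting computation; py_func runs arbitrary
    # Python with the privileges of the TensorFlow process.
    return tf.py_func(lambda: np.int32(1), [], tf.int32)

result = tf.cond(switch, unsafe, harmless)
```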
+
+In other words, graphs can contain vulnerabilities of their own. To allow users
+to provide checkpoints to a model you run on their behalf (e.g., in order to
+compare model quality for a fixed model architecture), you must carefully audit
+your model, and we recommend you run the TensorFlow process in a sandbox.
+
+## Accepting untrusted Inputs
+
+It is possible to write models that are secure in the sense that they can safely
+process untrusted inputs assuming there are no bugs. There are two main reasons
+to not rely on this: first, it is easy to write models which must not be exposed
+to untrusted inputs, and second, there are bugs in any software system of
+sufficient complexity. Letting users control inputs could allow them to trigger
+bugs either in TensorFlow or in dependent libraries.
+
+In general, it is good practice to isolate parts of any system which is exposed
+to untrusted (e.g., user-provided) inputs in a sandbox.
+
+A useful analogy to how any TensorFlow graph is executed is any interpreted
+programming language, such as Python. While it is possible to write secure
+Python code which can be exposed to user supplied inputs (by, e.g., carefully
+quoting and sanitizing input strings, size-checking input blobs, etc.), it is
+very easy to write Python programs which are insecure. Even secure Python code
+could be rendered insecure by a bug in the Python interpreter, or by a bug in a
+Python library used (e.g.,
+[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).
+
+## Running a TensorFlow server
+
+TensorFlow is a platform for distributed computing, and as such there is a
+TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
+internal communication only. It is not built for use in an untrusted network.**
+
+For performance reasons, the default TensorFlow server does not include any
+authorization protocol and sends messages unencrypted. It accepts connections
+from anywhere, and executes the graphs it is sent without performing any checks.
+Therefore, if you run a `tf.train.Server` in your network, anybody with
+access to the network can execute what you should consider arbitrary code with
+the privileges of the process running the `tf.train.Server`.
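To make the exposure concrete, a minimal sketch of bringing up a `tf.train.Server` (TF 1.x API); anything that can reach the bound port can submit graphs that run with this process's privileges:

```python
import tensorflow as tf

# A single-task cluster. Bind this only on an isolated network: the server
# accepts unauthenticated, unencrypted connections and executes whatever
# graphs it is sent.
cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
server = tf.train.Server(cluster, job_name="worker", task_index=0)
server.join()  # serve until the process is killed
```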
+
+When running distributed TensorFlow, you must isolate the network in which the
+cluster lives. Cloud providers provide instructions for setting up isolated
+networks, which are sometimes branded as "virtual private cloud." Refer to the
+instructions for
+[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
+[AWS](https://aws.amazon.com/vpc/) for details.
+
+Note that `tf.train.Server` is different from the server created by
+`tensorflow/serving` (the default binary for which is called `ModelServer`).
+By default, `ModelServer` also has no built-in mechanism for authentication.
+Connecting it to an untrusted network allows anyone on this network to run the
+graphs known to the `ModelServer`. This means that an attacker may run
+graphs using untrusted inputs as described above, but they would not be able to
+execute arbitrary graphs. It is possible to safely expose a `ModelServer`
+directly to an untrusted network, **but only if the graphs it is configured to
+use have been carefully audited to be safe**.
+
+Similar to best practices for other servers, we recommend running any
+`ModelServer` with appropriate privileges (i.e., using a separate user with
+reduced permissions). In the spirit of defense in depth, we recommend
+authenticating requests to any TensorFlow server connected to an untrusted
+network, as well as sandboxing the server to minimize the adverse effects of
+any breach.
+
+## Vulnerabilities in TensorFlow
+
+TensorFlow is a large and complex system. It also depends on a large set of
+third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
+It is possible that TensorFlow or its dependent libraries contain
+vulnerabilities that would allow triggering unexpected or dangerous behavior
+with specially crafted inputs.
+
+### What is a vulnerability?
+
+Given TensorFlow's flexibility, it is possible to specify computation graphs
+which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
+can perform arbitrary computations means that they may read and write files,
+communicate via the network, produce deadlocks and infinite loops, or run out
+of memory. It is only when these behaviors are outside the specifications of the
+operations involved that such behavior is a vulnerability.
+
+A `FileWriter` writing a file is not unexpected behavior and therefore is not a
+vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
+**is** a vulnerability.
+
+This is more subtle from a system perspective. For example, it is easy to cause
+a TensorFlow process to try to allocate more memory than available by specifying
+a computation graph containing an ill-considered `tf.tile` operation. TensorFlow
+should exit cleanly in this case (it would raise an exception in Python, or
+return an error `Status` in C++). However, if the surrounding system is not
+expecting the possibility, such behavior could be used in a denial of service
+attack (or worse). Because TensorFlow behaves correctly, this is not a
+vulnerability in TensorFlow (although it would be a vulnerability of this
+hypothetical system).
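A sketch of the `tf.tile` scenario above (TF 1.x API; shapes are illustrative): the graph is well-formed, yet a single malicious feed can demand an enormous allocation:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1024])
multiples = tf.placeholder(tf.int32, shape=[1])  # user-controlled input
y = tf.tile(x, multiples)

# Feeding multiples=[10**9] asks for roughly 4 TB. TensorFlow reports a clean
# error, but a surrounding system that never anticipated the failure may still
# fall over; the vulnerability would be in that system, not in TensorFlow.
```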
+
+As a general rule, it is incorrect behavior for TensorFlow to access memory it
+does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
+such behaviors constitute a vulnerability.
+
+One of the most critical parts of any system is input handling. If malicious
+input can trigger side effects or incorrect behavior, this is a bug, and likely
+a vulnerability.
+
+### Reporting vulnerabilities
+
+Please email reports about any security related issues you find to
+`security@tensorflow.org`. This mail is delivered to a small security team. Your
+email will be acknowledged within one business day, and you'll receive a more
+detailed response to your email within 7 days indicating the next steps in
+handling your report. For critical problems, you may encrypt your report (see
+below).
+
+Please use a descriptive subject line for your report email. After the initial
+reply to your report, the security team will endeavor to keep you informed of
+the progress being made towards a fix and announcement.
+
+If you believe that an existing (public) issue is security-related, please send
+an email to `security@tensorflow.org`. The email should include the issue ID and
+a short description of why it should be handled according to this security
+policy.
+
+Once an issue is reported, TensorFlow uses the following disclosure process:
+
+* When a report is received, we confirm the issue and determine its severity.
+* If we know of specific third-party services or software based on TensorFlow
+  that require mitigation before publication, those projects will be notified.
+* An advisory is prepared (but not published) which details the problem and
+  steps for mitigation.
+* Wherever possible, fixes are prepared for the last minor release of the two
+  latest major releases, as well as the master branch. We will attempt to
+  commit these fixes as soon as possible, and as close together as
+  possible.
+* Patch releases are published for all fixed released versions, a
+  notification is sent to discuss@tensorflow.org, and the advisory is published.
+
+Past security advisories are listed below. We credit reporters for identifying
+security issues, although we keep your name confidential if you request it.
+
+#### Encryption key for `security@tensorflow.org`
+
+If your disclosure is extremely sensitive, you may choose to encrypt your
+report using the key below. Please only use this for critical security
+reports.
+
+```
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
+LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
+aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
+SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
+OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
+e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
+dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
+gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
+AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
+o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
+T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
+HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
+3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
+2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
+N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
+zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
+osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
+nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
+OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
+BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
+BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
+J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
+A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
+7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
+Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
+v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
+=CDME
+-----END PGP PUBLIC KEY BLOCK-----
+```
+
+### Known vulnerabilities
+
+| Type | Versions affected | Reported by | Additional Information |
+|------|:-----------------:|-------------|------------------------|
+| out of bounds read| <=1.4 | @zhangbo5891001 | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
@@ -195,10 +195,10 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
       reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
     // TF_STRING and TF_RESOURCE tensors have a different representation in
     // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
-    // (any alignement requirements will be taken care of by TF_TensorToTensor
+    // (any alignment requirements will be taken care of by TF_TensorToTensor
     // and TF_TensorFromTensor).
     //
-    // Other types have the same represntation, so copy only if it is safe to do
+    // Other types have the same representation, so copy only if it is safe to do
     // so.
     buf->data_ = allocate_tensor("TF_NewTensor", len);
     std::memcpy(buf->data_, data, len);
@@ -2144,7 +2144,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
   }

-  // TOOD(skyewm): change to OutputTensor
+  // TODO(skyewm): change to OutputTensor
   tensorflow::ImportGraphDefResults results;
   TF_RETURN_IF_ERROR(
       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
@@ -46,6 +46,7 @@ tf_cuda_library(
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib_internal",
@@ -85,15 +85,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
     return nullptr;
   }

-  TFE_Context* ret = new TFE_Context(session);
-  ret->policy = opts->policy;
-  ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->session_options.options.env,
-      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
-  ret->rendezvous =
-      new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
-
-  return ret;
+  return new TFE_Context(*opts, session);
 }

 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
@@ -261,15 +253,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,

 void TFE_DeleteOp(TFE_Op* op) { delete op; }

-static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
-                                  TF_Status* status) {
-  // Questionable heuristic: Place the op on the same device as the first input
-  // placed outside of host memory?
-  if (IsCPU(op->device) && !IsCPU(device)) {
-    op->device = device;
-  }
-}
-
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
   tensorflow::Device* d = nullptr;
   if (device_name != nullptr && strlen(device_name) > 0) {
@@ -277,11 +260,24 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
     op->ctx->session->device_mgr->LookupDevice(device_name, &d);
     if (!status->status.ok()) return;
   }
-  TFE_OpSetDeviceHelper(op, d, status);
+  op->device = d;
+}
+
+const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
+  tensorflow::Device* device =
+      (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
+  return device->name().c_str();
 }

 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
-  TFE_OpSetDeviceHelper(op, h->d, status);
+  // Questionable heuristic ...
+  //
+  // Motivation: After an 'op' is placed on GPU because some of its earlier
+  // inputs are on GPU, we want to keep the 'op' there, even if some later
+  // inputs of it are not on GPU.
+  if (IsCPU(op->device) && !IsCPU(h->d)) {
+    op->device = h->d;
+  }
   if (!status->status.ok()) return;
   op->inputs.push_back(h->t);
   op->input_devices.push_back(h->d);
@@ -298,7 +294,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
     return TF_ATTR_INT;  // The compiler requires that we return something.
   }
   status->status =
-      tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
+      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
   return ret;
 }

@@ -154,6 +154,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);

 TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
                                            TF_Status* status);
+// The returned string remains valid throughout the lifetime of 'op'.
+TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
+                                                  TF_Status* status);
+
 TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);

@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"

 struct TFE_ContextOptions {
   TF_SessionOptions session_options;
@@ -43,9 +44,15 @@ struct TFE_ContextOptions {
 };

 struct TFE_Context {
-  explicit TFE_Context(TF_Session* s) : session(s) {}
+  explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
+      : policy(opts.policy),
+        session(s),
+        rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
+        pflr(new tensorflow::ProcessFunctionLibraryRuntime(
+            session->device_mgr, opts.session_options.options.env,
+            TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}

-  TFE_ContextDevicePlacementPolicy policy;
+  const TFE_ContextDevicePlacementPolicy policy;

 // Note: we cannot use C++11 thread_local here as there is no concept of a
 // thread-local-object-local variable in C++11.
@@ -54,8 +61,8 @@ struct TFE_Context {
       thread_local_policies GUARDED_BY(policy_map_mu);

   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
-  TF_Session* session;
-  tensorflow::Rendezvous* rendezvous;
+  TF_Session* const session;
+  tensorflow::Rendezvous* const rendezvous;

   tensorflow::mutex functions_mu;
   tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
@@ -64,14 +71,14 @@ struct TFE_Context {
   // One FunctionLibraryRuntime per device.
   // func_libs[i] is the FunctionLibraryRuntime corresponding to
   // session->devices[i].
-  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
+  const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;

   tensorflow::mutex cache_mu;
   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                      tensorflow::Fprint128Hasher>
       kernel_cache GUARDED_BY(cache_mu);

-  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
+  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
     return pflr->GetFLR(d->name());
   }

@@ -100,6 +107,8 @@ struct TFE_TensorHandle {
 };

 struct TFE_Op {
+  // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
+  // primitive operation.
   TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
       : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}

@@ -60,6 +60,31 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   return op;
 }

+// If there is a GPU device, returns true and sets 'gpu_device_name'
+// accordingly.
+bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  const int num_devices = TF_DeviceListCount(devices);
+  for (int i = 0; i < num_devices; ++i) {
+    const string device_type(TF_DeviceListType(devices, i, status.get()));
+    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    const string device_name(TF_DeviceListName(devices, i, status.get()));
+    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    if (device_type == "GPU") {
+      *gpu_device_name = device_name;
+      LOG(INFO) << "Found GPU device " << device_name;
+      TF_DeleteDeviceList(devices);
+      return true;
+    }
+  }
+  TF_DeleteDeviceList(devices);
+  return false;
+}
+
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
@@ -288,22 +313,15 @@ TEST(CAPI, TensorHandleSilentCopy) {
   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  const int num_devices = TF_DeviceListCount(devices);
-
   // Disable the test if no GPU is present.
-  if (num_devices > 1) {
-    const int device_to_use = 1;
-    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_TensorHandle* hgpu =
-        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

     TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
-    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
     TFE_TensorHandle* retvals[1];
     int num_retvals = 1;
@@ -314,7 +332,6 @@ TEST(CAPI, TensorHandleSilentCopy) {
     TFE_DeleteTensorHandle(hgpu);
   }

-  TF_DeleteDeviceList(devices);
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
   TFE_DeleteContext(ctx, status.get());
@@ -337,22 +354,15 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  const int num_devices = TF_DeviceListCount(devices);
-
   // Disable the test if no GPU is present.
-  if (num_devices > 1) {
-    const int device_to_use = 1;
-    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_TensorHandle* hgpu =
-        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

     TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
-    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
     TFE_TensorHandle* retvals[1];
     int num_retvals = 1;
@@ -363,13 +373,43 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
     TFE_DeleteTensorHandle(hgpu);
   }

-  TF_DeleteDeviceList(devices);
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
   TFE_DeleteContext(ctx, status.get());
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }

+TEST(CAPI, SetAndGetOpDevices) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+  TFE_Op* matmul = MatMulOp(ctx, m, m);
+
+  // Disable the test if no GPU is present.
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_OpSetDevice(matmul, "GPU:0", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    const char* device_name = TFE_OpGetDevice(matmul, status);
+    ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr);
+
+    TFE_OpSetDevice(matmul, "CPU:0", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    device_name = TFE_OpGetDevice(matmul, status);
+    ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr);
+  }
+
+  TFE_DeleteOp(matmul);
+  TFE_DeleteTensorHandle(m);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
   return Status::OK();
 }

-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list) {
-  CHECK(m);
-  auto* t = gtl::FindOrNull(*m, attr_name);
+  auto* t = gtl::FindOrNull(m, attr_name);
   if (t == nullptr) {
     return errors::InvalidArgument("Attribute '", attr_name,
                                    "' does not exist for this operation");
@@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a,
   b->high64 += a.high64;
 }

-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
                                             const tensorflow::Fprint128& b) {
   // TODO(agarwal): avoid ToString().
   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
   return FingerprintCat128(a, b);
 }

-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
   return CacheKeyHelper(s, {b, b});
 }

@@ -43,7 +43,7 @@ typedef std::unordered_map<string, uint32> AttrTypeMap;
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);

 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list);

 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
@@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) {

   TF_AttrType t;
   unsigned char is_list = 1;
-  s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
+  s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
   EXPECT_FALSE(s.ok());
   EXPECT_NE(is_list, 0);
-  s = AttrTypeByName(m, "transpose_a", &t, &is_list);
+  s = AttrTypeByName(*m, "transpose_a", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_BOOL, t);
   EXPECT_EQ(is_list, 0);

   s = AttrTypeMapForOp("Squeeze", &m);
   ASSERT_TRUE(s.ok()) << s;
-  s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
+  s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_INT, t);
   EXPECT_NE(is_list, 0);
@@ -18,12 +18,12 @@ limitations under the License.
 // Language-agnostic gradient tape. Does not perform backpropagation, just
 // maintains the data structures required to do so.

-#include <unordered_map>
-#include <unordered_set>
 #include <vector>
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {
@@ -54,11 +54,11 @@ struct OpTapeEntry {
 // Map from tensor_id to internally-defined operation-id of the operation which
 // produced this tensor. A value of -1 means that the tensor was directly
 // watched and not the result of any operation in the tape.
-using TensorTape = std::unordered_map<int64, int64>;
+using TensorTape = gtl::FlatMap<int64, int64>;

 // Map from operation-id to tape entry.
 template <typename BackwardFunction>
-using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;

 // Operations the tape needs to perform on tensors to do backpropagation. Named
 // "vspace" because a subset of these are related to a vector space, such as
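For orientation, a hedged Python mirror of the bookkeeping these typedefs implement (illustration only; the real structures are the `gtl::FlatMap`s above):

```python
# tensor_tape: tensor_id -> id of the op that produced the tensor, or -1 if
# the tensor was watched directly rather than produced by a taped op.
tensor_tape = {}

# op_tape: op_id -> entry recording the op's inputs, outputs, and backward fn.
op_tape = {}

def watch(tensor_id):
    tensor_tape[tensor_id] = -1

def record_operation(op_id, input_ids, output_ids, backward_fn):
    for t in output_ids:
        tensor_tape[t] = op_id
    op_tape[op_id] = (input_ids, output_ids, backward_fn)
```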
@ -159,7 +159,7 @@ class GradientTape {
|
|||||||
|
|
||||||
// Map from tensor id to number of remaining usages (i.e. how many entries in
|
// Map from tensor id to number of remaining usages (i.e. how many entries in
|
||||||
// the tape refer to it); to aid in tape garbage collection.
|
// the tape refer to it); to aid in tape garbage collection.
|
||||||
std::unordered_map<int64, int64> tensor_usage_;
|
gtl::FlatMap<int64, int64> tensor_usage_;
|
||||||
|
|
||||||
// If false, all activations are deleted in the first call to ComputeGradient.
|
// If false, all activations are deleted in the first call to ComputeGradient.
|
||||||
// Else, only when this is destructed.
|
// Else, only when this is destructed.
|
||||||
@ -286,11 +286,11 @@ struct BackpropInitialState {
|
|||||||
|
|
||||||
// Map from tensor ID to how many references still exist for this tensor in
|
// Map from tensor ID to how many references still exist for this tensor in
|
||||||
// the tape.
|
// the tape.
|
||||||
std::unordered_map<int64, int64> tensor_usage_counts;
|
gtl::FlatMap<int64, int64> tensor_usage_counts;
|
||||||
|
|
||||||
// Maps from op ID to how many output tensors of this op still need to have
|
// Maps from op ID to how many output tensors of this op still need to have
|
||||||
// their gradients computed.
|
// their gradients computed.
|
||||||
std::unordered_map<int64, int64> op_missing_tensor;
|
gtl::FlatMap<int64, int64> op_missing_tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
// If `persistent_tape` is true, op_tape is not changed and none of the
|
// If `persistent_tape` is true, op_tape is not changed and none of the
|
||||||
@ -301,8 +301,8 @@ struct BackpropInitialState {
|
|||||||
template <typename BackwardFunction>
|
template <typename BackwardFunction>
|
||||||
BackpropInitialState<BackwardFunction> PrepareBackprop(
|
BackpropInitialState<BackwardFunction> PrepareBackprop(
|
||||||
     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape<BackwardFunction>* op_tape,
-    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
+    OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
+    bool persistent_tape) {
   std::vector<int64> tensor_stack;
   tensor_stack.reserve(target.size());
   for (auto t : target) {
@@ -362,7 +362,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
 template <typename BackwardFunction>
 std::vector<int64> InitialStack(
     const OpTape<BackwardFunction>& op_tape,
-    const std::unordered_map<int64, int64>& op_missing_tensor) {
+    const gtl::FlatMap<int64, int64>& op_missing_tensor) {
   std::vector<int64> result;
   for (auto& op_entry : op_tape) {
     if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
@@ -373,13 +373,13 @@ std::vector<int64> InitialStack(
 }
 
 template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(
-    const VSpace<Gradient, BackwardFunction>& vspace,
+Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
     gtl::ArraySlice<int64> target_tensor_ids,
-    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+    gtl::ArraySlice<Gradient*> output_gradients,
+    const TensorTape& tensor_tape,
     const OpTape<BackwardFunction>& op_tape,
-    const std::unordered_map<int64, int64>& tensor_usage_counts,
-    std::unordered_map<int64, std::vector<Gradient*>>* result) {
+    const gtl::FlatMap<int64, int64>& tensor_usage_counts,
+    gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
   for (int i = 0; i < target_tensor_ids.size(); ++i) {
     const int64 id = target_tensor_ids[i];
     if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
@@ -441,13 +441,13 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     gtl::ArraySlice<int64> source_tensor_ids,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
-  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
+  gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
                                   source_tensor_ids.end());
   BackpropInitialState<BackwardFunction> state = PrepareBackprop(
       target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
-  std::unordered_map<int64, std::vector<Gradient*>> gradients;
+  gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
   Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
                               tensor_tape_, state.op_tape,
                               state.tensor_usage_counts, &gradients);
@@ -463,7 +463,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     cleanup();
     return s;
   }
-  std::unordered_map<int64, int64> gradients_size;
+  gtl::FlatMap<int64, int64> gradients_size;
   // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
   // time, for better CPU backprop performance.
   VLOG(1) << "Initial stack:";
@@ -472,8 +472,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     VLOG(1) << "  " << t;
   }
 }
-std::unordered_map<string, std::unordered_set<int>>
-    functions_accept_none_for_indices({
+gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
     {"SoftmaxCrossEntropyWithLogits", {1}},
     {"FusedBatchNorm", {1, 2, 3, 4}},
 });
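The hunks above mechanically swap `std::unordered_map`/`std::unordered_set` for `gtl::FlatMap`/`gtl::FlatSet`, TensorFlow's open-addressed flat hash containers, on the eager-tape hot path. A minimal sketch of why the swap is a drop-in (illustrative usage only; assumes the internal `gtl` headers are available):

```cpp
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Illustrative only: FlatMap/FlatSet expose the subset of the
// unordered_map/unordered_set interface the tape code relies on
// (find/end/count/operator[]), so the substitution needs no logic changes.
void FlatContainerSketch() {
  gtl::FlatMap<int64, int64> usage_counts;  // was std::unordered_map
  usage_counts[7] = 1;
  if (usage_counts.find(7) != usage_counts.end()) {
    ++usage_counts[7];
  }

  gtl::FlatSet<int64> sources;  // was std::unordered_set
  sources.insert(7);
  const bool tracked = sources.count(7) > 0;
  (void)tracked;
}

}  // namespace tensorflow
```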
@@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc(
         "linalg_ops",
         "logging_ops",
         "lookup_ops",
+        "manip_ops",
         "math_ops",
         "nn_ops",
         "no_op",
@@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
 Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                 const SessionOptions& session_options,
                                 std::unique_ptr<Session>* session) {
-  session->reset(NewSession(session_options));
+  Session* session_p = nullptr;
+  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
+  session->reset(session_p);
   return (*session)->Create(meta_graph_def.graph_def());
 }
 
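The hunk above switches from the pointer-returning `NewSession` overload, which yields a null session with no diagnostic on failure, to the `Status`-returning overload so the creation error propagates to the caller (and becomes testable, as the new `SessionCreationFailure` test below exercises). A small sketch of the pattern (the helper name is illustrative):

```cpp
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

// Illustrative helper: surface session-creation errors instead of silently
// producing a null Session*.
tensorflow::Status MakeSession(const tensorflow::SessionOptions& options,
                               std::unique_ptr<tensorflow::Session>* out) {
  tensorflow::Session* session = nullptr;
  // The Status-returning overload reports *why* creation failed.
  TF_RETURN_IF_ERROR(tensorflow::NewSession(options, &session));
  out->reset(session);
  return tensorflow::Status::OK();
}
```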
@@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
       << st.error_message();
 }
 
+TEST_F(LoaderTest, SessionCreationFailure) {
+  SavedModelBundle bundle;
+  // Use invalid SessionOptions to cause session creation to fail. Default
+  // options work, so provide an invalid value for the target field.
+  SessionOptions session_options;
+  constexpr char kInvalidTarget[] = "invalid target";
+  session_options.target = kInvalidTarget;
+  RunOptions run_options;
+
+  const string export_dir =
+      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+  Status st = LoadSavedModel(session_options, run_options, export_dir,
+                             {kSavedModelTagServe}, &bundle);
+  EXPECT_FALSE(st.ok());
+  EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget))
+      << st.error_message();
+}
+
 TEST_F(LoaderTest, PbtxtFormat) {
   SavedModelBundle bundle;
   SessionOptions session_options;
@@ -23,7 +23,6 @@ cc_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:tensorflow",
     ],
 )
 
@@ -4,7 +4,7 @@
 
 To use from your BUILD file, add the following line to load the macro:
 
-load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
 
 Then call the macro like this:
 
@@ -16,14 +16,15 @@ tf_library(
 )
 """
 
-load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
+load("//tensorflow:tensorflow.bzl",
+     "if_android", "tf_cc_test", "tf_copts")
 
 def tf_library(name, graph, config,
                freeze_checkpoint=None, freeze_saver=None,
                cpp_class=None, gen_test=True, gen_benchmark=True,
               visibility=None, testonly=None,
                tfcompile_flags=None,
-               tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile",
+               tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
                include_standard_runtime_deps=True, deps=None, tags=None):
   """Runs tfcompile to compile a TensorFlow graph into executable code.
 
@@ -119,9 +120,9 @@ def tf_library(name, graph, config,
             out_nodes_file,
         ] + freeze_saver_srcs,
         outs=[freeze_file],
-        cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
+        cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
             freeze_args),
-        tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
+        tools=["//tensorflow/python/tools:freeze_graph"],
         tags=tags,
     )
     tfcompile_graph = freeze_file
@@ -213,22 +214,22 @@ def tf_library(name, graph, config,
       # These deps are required by all tf_library targets even if
       # include_standard_runtime_deps is False. Without them, the
       # generated code will fail to compile.
-      "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
-      "@org_tensorflow//tensorflow/core:framework_lite",
+      "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+      "//tensorflow/core:framework_lite",
   ] + (need_xla_data_proto and [
       # If we're generating the program shape, we must depend on the proto.
-      "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
+      "//tensorflow/compiler/xla:xla_data_proto",
   ] or []) + (include_standard_runtime_deps and [
       # TODO(cwhipkey): only depend on kernel code that the model actually needed.
-      "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
-      "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
-      "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
+      "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
+      "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
+      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
+      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
+      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
+      "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+      "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
+      "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
+      "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
       "//third_party/eigen3",
   ] or []) + (deps or []),
   tags=tags,
@@ -254,28 +255,32 @@ def tf_library(name, graph, config,
         name=("gen_" + test_name),
         testonly=1,
         srcs=[
-            "@org_tensorflow//tensorflow/compiler/aot:test.cc",
+            "//tensorflow/compiler/aot:test.cc",
             header_file,
         ],
         outs=[test_file],
         cmd=("sed " + sed_replace +
-             " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
+             " $(location //tensorflow/compiler/aot:test.cc) " +
             "> $(OUTS)"),
         tags=tags,
     )
 
-    # The cc_test rule for the generated code.
-    native.cc_test(
+    # The cc_test rule for the generated code. To ensure that this works
+    # reliably across build configurations, we must use tf_cc_test instead of
+    # native.cc_test. This is related to how we build
+    # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
+    # for more details.
+    tf_cc_test(
         name=test_name,
         srcs=[test_file],
         deps=[
             ":" + name,
-            "@org_tensorflow//tensorflow/compiler/aot:runtime",
-            "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main",
-            "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+            "//tensorflow/compiler/aot:runtime",
+            "//tensorflow/compiler/aot:tf_library_test_main",
+            "//tensorflow/compiler/xla:executable_run_options",
             "//third_party/eigen3",
-            "@org_tensorflow//tensorflow/core:lib",
-            "@org_tensorflow//tensorflow/core:test",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:test",
         ],
         tags=tags,
     )
@@ -283,7 +288,7 @@ def tf_library(name, graph, config,
   if gen_benchmark:
     benchmark_name = name + "_benchmark"
     benchmark_file = benchmark_name + ".cc"
-    benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" +
+    benchmark_main = ("//tensorflow/compiler/aot:" +
                       "benchmark_main.template")
 
     # Rule to rewrite benchmark.cc to produce the benchmark_file.
@@ -301,7 +306,9 @@ def tf_library(name, graph, config,
         tags=tags,
     )
 
-    # The cc_benchmark rule for the generated code.
+    # The cc_benchmark rule for the generated code. This does not need the
+    # tf_cc_binary since we (by deliberate design) do not depend on
+    # //tensorflow/core:lib.
     #
     # Note: to get smaller size on android for comparison, compile with:
     #    --copt=-fvisibility=hidden
@@ -315,12 +322,12 @@ def tf_library(name, graph, config,
         linkopts = if_android(["-pie", "-s"]),
         deps=[
             ":" + name,
-            "@org_tensorflow//tensorflow/compiler/aot:benchmark",
-            "@org_tensorflow//tensorflow/compiler/aot:runtime",
-            "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+            "//tensorflow/compiler/aot:benchmark",
+            "//tensorflow/compiler/aot:runtime",
+            "//tensorflow/compiler/xla:executable_run_options",
             "//third_party/eigen3",
         ] + if_android([
-            "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android",
+            "//tensorflow/compiler/aot:benchmark_extra_android",
         ]),
         tags=tags,
     )
@@ -330,11 +337,11 @@ def target_llvm_triple():
   # TODO(toddw): Add target_triple for other targets. For details see:
   # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
   return select({
-      "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android",
-      "@org_tensorflow//tensorflow:android_arm": "armv7-none-android",
-      "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android",
-      "@org_tensorflow//tensorflow:android_x86": "i686-none-android",
-      "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
-      "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin",
+      "//tensorflow:android_armeabi": "armv5-none-android",
+      "//tensorflow:android_arm": "armv7-none-android",
+      "//tensorflow:android_arm64": "aarch64-none-android",
+      "//tensorflow:android_x86": "i686-none-android",
+      "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
+      "//tensorflow:darwin": "x86_64-none-darwin",
       "//conditions:default": "x86_64-pc-linux",
   })
@@ -30,12 +30,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -141,8 +143,7 @@ struct NodeSlot {
 // everything to use it.
 static const char* const kArgOp = "_Arg";
 static const char* const kRetValOp = "_Retval";
-static const char* const kSendToHostOp = "_XlaSendToHost";
-static const char* const kRecvFromHostOp = "_XlaRecvFromHost";
+static const char* const kHostComputeOp = "_XlaHostCompute";
 static const char* const kSendFromHostOp = "_XlaSendFromHost";
 static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
 
@@ -171,7 +172,8 @@ class Encapsulator {
 
   // Write a copy of the input graph to 'graph_out', where the subgraphs are
   // replaced with calls to the new functions.
-  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
+  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+                          FunctionLibraryDefinition* library);
 
  private:
  // A subgraph of the input, all marked with a common 'group_attribute'
@@ -201,21 +203,29 @@ class Encapsulator {
   //      .. .
   //  RAH --> C --> SFH
   //
-  // The compiled cluster is as follows. STH is a SendToHost node which is the
-  // source of a channel to the RAH node above. RFH is a RecvFromHost node which
-  // is the destination of a channel from the SFH node above. There is a control
-  // edge that ensures RFH follows STH, which is used in shape inference to
-  // ensure that the shapes on the STH host channel are known before the RFH
-  // channel is compiled.
+  // The compiled cluster is as follows. HC is a HostCompute node which is the
+  // source of a channel to the RAH node above and the destination of a channel
+  // from the SFH node above.
   //
-  //  Arg --> B --> STH ..> RFH --> D --> Retval
+  //  Arg --> B --> HC --> D --> Retval
   //
-  // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most
-  // one RAH and SFH in each compiled cluster. This design is preferred over
-  // adding separate Arg/Retval nodes for each transmitted value because it
-  // simplifies the host code that would like to limit communication between
-  // host and device and, e.g., raise only one interrupt per channel rather than
-  // one per transmitted value.
+  // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
+  // at most one RAH and SFH in each outside_compilation cluster. This design is
+  // preferred over adding separate Arg/Retval nodes for each transmitted value
+  // because it allows optimizations to the host code that would like to limit
+  // communication between host and device and, e.g., raise only one interrupt
+  // per channel rather than one per transmitted value.
+  //
+  // The shapes of the outputs from the HC node in general cannot be determined
+  // until the shapes of its inputs are known at compile time, since e.g.,
+  // above, the shape of C's outputs aren't known until the shape of its inputs
+  // are known. If the shapes of the HC's outputs can be determined during the
+  // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
+  // graph is stored in the shape_inference_graph attr. This graph can be used
+  // when compiling the HC Op to determine the shape of the SFH inputs given
+  // the shapes of any ancestor RAH outputs. If it can be determined that the
+  // shape of the SFH inputs will not be inferrable even once the shapes of the
+  // RAH outputs are known, an error is returned by the rewriter.
  class Subgraph {
    public:
    // Creates a graph to build the subgraph in, if it doesn't already exist,
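As the comment block above describes, each outside_compilation cluster is represented inside the compiled subgraph by a single `_XlaHostCompute` (HC) node, tied to its host-side `_XlaRecvAtHost`/`_XlaSendFromHost` peers by a shared channel key. A condensed sketch of how the pass materializes the HC node (it mirrors the `AddHostComputes` hunk further down; treat it as illustrative, not the authoritative implementation):

```cpp
#include <string>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

// Illustrative: builds the _XlaHostCompute NodeDef for one
// outside_compilation cluster. The "key" attr ties the HC node to its
// _XlaRecvAtHost/_XlaSendFromHost peers on the host side.
Status BuildHostComputeDef(const string& subgraph_name,
                           const string& oc_subgraph_name,
                           gtl::ArraySlice<NodeDefBuilder::NodeOut> inputs,
                           const std::vector<DataType>& input_dtypes,
                           const std::vector<DataType>& output_dtypes,
                           NodeDef* out) {
  NodeDefBuilder builder(strings::StrCat("outside_compilation_",
                                         oc_subgraph_name, "_host_compute"),
                         "_XlaHostCompute");
  builder.Input(inputs);
  builder.Attr("Tinputs", input_dtypes);
  builder.Attr("Toutputs", output_dtypes);
  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
                                      "_", oc_subgraph_name));
  return builder.Finalize(out);
}

}  // namespace tensorflow
```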
@@ -246,6 +256,10 @@ class Encapsulator {
         const std::unordered_map<const Node*, Node*>& node_images,
         Graph* graph_out);
 
+    // Returns the names of all the outside_compilation subgraphs in this
+    // Subgraph.
+    void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
+
     // Returns the Node that inputs to the function should be wired up to.
     Node* GetCallNodeForInputs() const;
 
@@ -305,15 +319,9 @@ class Encapsulator {
     void RecordOutsideCompilationOutputOrControl(
         const string& outside_compilation_id, const Edge* edge);
 
-    // Adds the SendToHost nodes for each outside_compilation subgraph once the
-    // edges have all been recorded via RecordOutsideCompilationInputOrControl.
-    Status AddSendsToOutsideCompilation(
-        const std::unordered_map<const Node*, Node*>& node_images);
-
-    // Adds the RecvFromHost nodes for each outside_compilation subgraph once
-    // the edges have all been recorded via
-    // RecordOutsideCompilationOutputOrControl.
-    Status AddRecvsFromOutsideCompilation(
+    // Adds the HostCompute nodes for each outside_compilation subgraph.
+    Status AddHostComputes(
+        const string& subgraph_name,
         const std::unordered_map<const Node*, Node*>& node_images);
 
     // Creates the sequencer node if it doesn't exist, adding it to graph_out.
@@ -323,10 +331,16 @@ class Encapsulator {
     // all the downstream nodes of call_node_outputs.
     void ConnectSequencerToOutputs(Graph* graph_out);
 
+    Status AddShapeInferenceInfo(
+        const string& outside_compilation_subgraph_name,
+        const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
+
+    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
+
    private:
     struct OutsideCompilationSubgraph {
       // Map from source (producer node/slot) tensors in the original graph to
-      // input index (slot number in the SendToHost/RecvAtHost nodes that will
+      // input index (slot number in the HostCompute/RecvAtHost nodes that will
       // be created) for the outside_compilation subgraph.
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
 
@@ -335,14 +349,14 @@ class Encapsulator {
       // outside_compilation subgraph. These are recorded by
       // RecordOutsideCompilationInputOrControl while walking all the subgraph
       // edges, and lifted control edges within the subgraph are added by
-      // AddSendsToOutsideCompilation once the _SendToHost node has been
+      // AddSendsToOutsideCompilation once the _HostCompute node has been
       // created. The matching control edge from _RecvAtHost to the
       // destination is added by CopyEdgeToOutputGraph.
       std::unordered_set<const Node*> control_inputs;
 
       // Maps from source (producer node/slot) and destination (consumer
       // node/slot) tensors in the original graph to output index (slot number
-      // in the SendFromHost/RecvFromHost nodes that will be created) for the
+      // in the SendFromHost/HostCompute nodes that will be created) for the
       // outside_compilation subgraph.
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
       std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
@@ -352,13 +366,13 @@ class Encapsulator {
       // containing compiled subgraph. These are recorded by
       // RecordOutsideCompilationOutputOrControl while walking all the subgraph
       // edges, and lifted control edges within the subgraph are added by
-      // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been
+      // AddRecvsFromToOutsideCompilation once the _HostCompute node has been
       // created. The matching control edge from the source to _SendFromHost to
       // the destination is added by CopyEdgeToOutputGraph.
       std::unordered_set<const Node*> control_outputs;
 
-      // _SendToHost node in the subgraph. Not owned.
-      Node* send_to_host = nullptr;
+      // Name of the _HostCompute node in the subgraph.
+      string host_compute_name;
 
       // _RecvAtHost node in the output graph. Not owned.
       Node* recv_at_host = nullptr;
@@ -516,6 +530,59 @@ class Encapsulator {
       const std::unordered_map<const Node*, Node*>& node_images,
       bool parallel_checking, Graph* graph_out);
 
+  // Constructs a minimal shape inference graph that can be used to determine
+  // the shape of send_node at the time that the subgraph is compiled.
+  // recv_at_host_nodes contains the names of all the recv_at_host nodes that
+  // send_node might depend on. These recv_at_host nodes have shapes that are
+  // not known during the rewrite pass, but will be known at compile time.
+  //
+  // If the shapes of all the inputs to send_node can be determined during the
+  // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
+  // static_shape_out. Otherwise graphdef_out contains a graph that can be used
+  // for shape inference at compile time, where all the source nodes of the
+  // graph are either constants with known shapes, or nodes named in
+  // recv_at_host_nodes.
+  //
+  // A non-OK status is returned if neither of the above conditions can be
+  // satisfied, e.g., because send_node depends on a node that doesn't have a
+  // registered shape inference function.
+  Status DoStaticShapeInferenceForOutsideCompilationSend(
+      const Graph& graph_in, const ShapeRefiner& shape_refiner,
+      const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+      FunctionLibraryDefinition* library,
+      std::vector<TensorShapeProto>* static_shape_out,
+      std::unique_ptr<GraphDef>* graphdef_out);
+
+  // Makes a copy of graph containing only nodes that are ancestors of at least
+  // one node in send_from_host_nodes and store it in pruned_graph. On exit
+  // nodes_images contains a mapping from nodes in graph to nodes in
+  // pruned_graph. All functions in the copied graph are inlined.
+  Status MakePrunedGraphCopyAndInline(
+      const Graph& graph, const std::vector<Node*>& sink_nodes,
+      std::unique_ptr<Graph>* pruned_graph,
+      std::unordered_map<const Node*, Node*>* node_images,
+      FunctionLibraryDefinition* library);
+
+  // Makes a copy of graph containing only nodes that are ancestors of a
+  // send_from_host node in an outside_compilation subgraph, and store it in
+  // pruned_graph. Also perform shape inference on the pruned graph, using
+  // shape_refiner. On exit node_images contains a mapping from nodes in graph
+  // to nodes in pruned_graph.
+  Status MakeGraphForOutsideCompilationSends(
+      const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+      ShapeRefiner* shape_refiner,
+      std::unordered_map<const Node*, Node*>* node_images,
+      FunctionLibraryDefinition* library);
+
+  // Performs static shape inference, as far as possible, for the send_from_host
+  // nodes in each outside_compilation subgraph. Where it is not possible to
+  // determine the shape statically, stores a serialized GraphDef in the
+  // HostCompute 'shape_inference_graph' attr, to be used at compile time for
+  // final inference. If the shapes are known statically they are stored in the
+  // HostCompute 'shapes' attr.
+  Status GetShapeInfoForOutsideCompilationSends(
+      Graph* graph_out, FunctionLibraryDefinition* library);
+
   const string group_attribute_;
   const string outside_compilation_attribute_;
   const Graph* graph_in_;
@@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
   }
 }
 
-Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
+Status Encapsulator::Subgraph::AddHostComputes(
+    const string& subgraph_name,
     const std::unordered_map<const Node*, Node*>& node_images) {
   for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
     const string& oc_subgraph_name = oc_subgraph_iter.first;
     OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
-    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
-      // Build a _SendToHost node sending all the args of the appropriate
-      // types.
-      std::vector<DataType> dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
+        !oc_subgraph.outputs_by_src.empty() ||
+        !oc_subgraph.control_outputs.empty()) {
+      // Build a _HostCompute node.
       std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
+      std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
+      std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
+                                          DT_INVALID);
+
       for (const auto& input_src : oc_subgraph.inputs) {
         const Node* src_node = input_src.first.node;
@@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
         int input_index = input_src.second;
 
         DataType dtype = src_node->output_type(src_slot);
-        dtypes[input_index] = dtype;
         inputs[input_index].Reset(src_image->name(), src_slot, dtype);
+        input_dtypes[input_index] = dtype;
       }
 
-      NodeDef send_def;
-      NodeDefBuilder builder(
-          strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"),
-          kSendToHostOp);
-      builder.Attr("dtypes", dtypes);
+      for (const auto& output : oc_subgraph.outputs_by_src) {
+        DataType dtype = output.first.dtype;
+        int output_index = output.second;
+        output_dtypes[output_index] = dtype;
+      }
+
+      NodeDef host_compute_def;
+      NodeDefBuilder builder(strings::StrCat("outside_compilation_",
+                                             oc_subgraph_name, "_host_compute"),
+                             kHostComputeOp);
       builder.Input(inputs);
-      Status s = builder.Finalize(&send_def);
+      builder.Attr("Tinputs", input_dtypes);
+      builder.Attr("Toutputs", output_dtypes);
+      builder.Attr("key",
+                   strings::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
+      Status s = builder.Finalize(&host_compute_def);
       if (!s.ok()) return s;
 
-      oc_subgraph.send_to_host = graph_->AddNode(send_def, &s);
+      Node* host_compute = graph_->AddNode(host_compute_def, &s);
       if (!s.ok()) return s;
+      oc_subgraph.host_compute_name = host_compute->name();
 
-      // Connect the _SendToHost node to its producers in the subgraph.
+      // Connect the _HostCompute node to its producers in the subgraph.
       for (auto& input_src : oc_subgraph.inputs) {
         const Node* src_node = input_src.first.node;
         Node* src_image = node_images.at(src_node);
         int src_slot = input_src.first.slot;
         int input_index = input_src.second;
-        graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host,
-                        input_index);
+        graph_->AddEdge(src_image, src_slot, host_compute, input_index);
       }
 
-      // Connect the _SendToHost node to its control edge producers in the
+      // Connect the _HostCompute node to its control edge producers in the
       // subgraph.
       for (const auto& src_node : oc_subgraph.control_inputs) {
         Node* src_image = node_images.at(src_node);
-        graph_->AddControlEdge(src_image, oc_subgraph.send_to_host);
-      }
-    }
-  }
+        graph_->AddControlEdge(src_image, host_compute);
       }
 
-  return Status::OK();
-}
-
-Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation(
-    const std::unordered_map<const Node*, Node*>& node_images) {
-  for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
-    const string& oc_subgraph_name = oc_subgraph_iter.first;
-    OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
-    if (!oc_subgraph.outputs_by_src.empty() ||
-        !oc_subgraph.control_outputs.empty()) {
-      // Build a _RecvFromHost node producing all the outputs of the appropriate
-      // types.
-      std::vector<DataType> dtypes(oc_subgraph.outputs_by_src.size(),
-                                   DT_INVALID);
-
-      for (const auto& output : oc_subgraph.outputs_by_src) {
-        DataType dtype = output.first.dtype;
-        int output_index = output.second;
-        dtypes[output_index] = dtype;
-      }
-
-      NodeDef recv_def;
-      NodeDefBuilder builder(
-          strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"),
-          kRecvFromHostOp);
-      builder.Attr("dtypes", dtypes);
-      Status s = builder.Finalize(&recv_def);
-      if (!s.ok()) return s;
-
-      Node* recv = graph_->AddNode(recv_def, &s);
-      if (!s.ok()) return s;
-
-      // Connect the consumers in the subgraph to the _RecvFromHost node.
+      // Connect the consumers in the subgraph to the _HostCompute node.
       for (const auto& output : oc_subgraph.outputs_by_dst) {
         const Node* dst_node = output.first.node;
         Node* dst_image = node_images.at(dst_node);
         int dst_slot = output.first.slot;
         int output_index = output.second;
 
-        graph_->AddEdge(recv, output_index, dst_image, dst_slot);
+        graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
       }
 
-      // Connect the control edge consumers in the subgraph to the _RecvFromHost
+      // Connect the control edge consumers in the subgraph to the _HostCompute
       // node.
       for (const auto& dst_node : oc_subgraph.control_outputs) {
         Node* dst_image = node_images.at(dst_node);
-        graph_->AddControlEdge(recv, dst_image);
+        graph_->AddControlEdge(host_compute, dst_image);
       }
-
-      // Add a control edge in the subgraph so that the _SendToHost node, if
-      // any, is compiled before the _RecvFromHost node.
-      if (oc_subgraph.send_to_host != nullptr) {
-        graph_->AddControlEdge(oc_subgraph.send_to_host, recv);
-      }
     }
   }
 }
@@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
   return Status::OK();
 }
 
+Status Encapsulator::Subgraph::AddShapeInferenceInfo(
+    const string& outside_compilation_subgraph_name,
+    const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
+  OutsideCompilationSubgraph& oc_subgraph =
+      outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
+
+  Node* host_compute = nullptr;
+  for (Node* n : graph_->nodes()) {
+    if (n->name() == oc_subgraph.host_compute_name) {
+      host_compute = n;
+      break;
+    }
+  }
+  if (host_compute == nullptr) {
+    return errors::InvalidArgument(
+        "After rewriting subgraph ", outside_compilation_subgraph_name,
+        " there is no HostCompute Op for outside compilation subgraph ",
+        oc_subgraph.host_compute_name);
+  }
+
+  if (inference_graph == nullptr) {
+    host_compute->AddAttr("shape_inference_graph", "");
+    host_compute->AddAttr("shapes", shapes);
+  } else {
+    string serialized_graph;
+    if (!inference_graph->SerializeToString(&serialized_graph)) {
+      return errors::Internal(
+          "Failed to serialize graph for outside compilation subgraph ",
+          oc_subgraph.host_compute_name);
+    }
+    host_compute->AddAttr("shape_inference_graph", serialized_graph);
+    host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
+  }
+  return Status::OK();
+}
+
+Status Encapsulator::Subgraph::ReplaceFunctionDef(
+    FunctionLibraryDefinition* library) {
+  const string& name = call_node_def_.name();
+
+  FunctionDef fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
+
+  if (VLOG_IS_ON(1)) {
+    VLOG(2) << "Replace function def " << name;
+    dump_graph::DumpGraphToFile(
+        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+        library);
+    dump_graph::DumpFunctionDefToFile(
+        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+  }
+
+  TF_RETURN_IF_ERROR(library->RemoveFunction(name));
+  TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+  return Status::OK();
+}
+
 Status Encapsulator::Subgraph::BuildParallelCheckOp(
     const std::unordered_map<const Node*, Node*>& node_images,
     Graph* graph_out) {
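`AddShapeInferenceInfo` above encodes its result in one of two mutually exclusive attrs: statically known shapes go in `shapes` with an empty `shape_inference_graph`, otherwise the serialized inference `GraphDef` is stored for use at compile time. A hedged sketch of the consumer side (the helper name is hypothetical):

```cpp
#include <string>

#include "tensorflow/core/framework/graph.pb.h"

// Hypothetical consumer-side helper: recover the inference graph stored by
// AddShapeInferenceInfo, or report that shapes were statically known.
bool RecoverInferenceGraph(const std::string& serialized,
                           tensorflow::GraphDef* graph) {
  if (serialized.empty()) {
    // Static case: the 'shapes' attr carries the answer instead.
    return false;
  }
  // Inverse of the SerializeToString() call above.
  return graph->ParseFromString(serialized);
}
```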
@@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_recv"),
                          kRecvAtHostOp);
-  builder.Attr("dtypes", dtypes);
+  builder.Attr("Toutputs", dtypes);
+  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+                                      "_", oc_subgraph_name));
   Status s = builder.Finalize(&recv_def);
   if (!s.ok()) return s;
 
@@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
   NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
                                          "_", oc_subgraph_name, "_send"),
                          kSendFromHostOp);
-  builder.Attr("dtypes", dtypes);
+  builder.Attr("Tinputs", dtypes);
+  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
+                                      "_", oc_subgraph_name));
   builder.Input(inputs);
   Status s = builder.Finalize(&send_def);
   if (!s.ok()) return s;
@@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
   return Status::OK();
 }
 
+void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
+    std::vector<string>* names) const {
+  for (auto& entry : outside_compilation_subgraphs_) {
+    names->push_back(entry.first);
+  }
+}
+
 Status Encapsulator::GetFunctionNameAttr(
     Node const* node, string* attr, string* outside_compilation_attr) const {
   Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
@@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
   // single input and output node for it.
   for (auto& entry : subgraphs_) {
     Subgraph& subgraph = entry.second;
-    TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images));
-    TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images));
+    TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
   }
 
   MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
@@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph(
   return Status::OK();
 }
 
-Status Encapsulator::BuildOutputGraph(bool parallel_checking,
+namespace {
+
+// Adds a dummy Const node to graph_out. The "constant" has the type of
+// data_type and the shape indicated in 'shape'. The dummy node is not a valid
+// Const node because it does not have any value defined, but this doesn't
+// matter because it will only be used subsequently for shape inference. (It
+// would be possible to add a switch statement over data_type to create a value
+// for the constant, but that would entail maintaining the logic as new types
+// are added, and is not necessary.)
+Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
                          Graph* graph_out) {
+  TensorProto dummy_proto;
+  dummy_proto.set_dtype(data_type);
+  *dummy_proto.mutable_tensor_shape() = shape;
+  // Don't set any value field in the proto, since it is only going to be used
+  // for shape inference.
+
+  GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
+  NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
+                           options.op_registry());
+  node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
+  return options.FinalizeBuilder(&node_builder);
+}
+
+// Adds a copy of node_in to graph_out and adds the mapping to
+// copied_node_images.
+Status CopyShapeInferenceNodeToGraph(
+    Node* node_in, const Node* send_node,
+    const std::unordered_map<Node*, Node*>& dummy_node_images,
+    FunctionLibraryDefinition* library,
+    std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
+  // Once all the ancestor nodes have been added to graph_out, add this node
+  // and connect it to its ancestors.
+  Node* node_out = graph_out->CopyNode(node_in);
+  (*copied_node_images)[node_in] = node_out;
+  // Don't bother to build the shape inference graph if there's a node with no
+  // shape inference function, since it would just result in an error later at
+  // compile time.
+  const OpRegistrationData* op_reg_data;
+  TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
+  if (op_reg_data->shape_inference_fn == nullptr) {
+    return errors::InvalidArgument(
+        "Shape inference is not possible for outside_compilation "
+        "SendFromHost node ",
+        send_node->name(), " because it depends on node ", node_in->name(),
+        " which does not have a shape inference function registered.");
+  }
+  // Add all the edges to the newly copied node.
+  for (const Edge* in_edge : node_in->in_edges()) {
+    if (!in_edge->IsControlEdge()) {
+      Node* src = in_edge->src();
+      const auto iter = dummy_node_images.find(src);
+      if (iter == dummy_node_images.end()) {
+        // The src is a copied node so use the original output port.
+        graph_out->AddEdge((*copied_node_images)[in_edge->src()],
+                           in_edge->src_output(), node_out,
+                           in_edge->dst_input());
+      } else {
+        // The src is a dummy node so use output port 0.
+        graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
+    const Graph& graph_in, const ShapeRefiner& shape_refiner,
+    const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
+    FunctionLibraryDefinition* library,
+    std::vector<TensorShapeProto>* static_shape_out,
+    std::unique_ptr<GraphDef>* graphdef_out) {
+  // Maps from nodes in graph_in to nodes in graph_out.
+  //
+  // When an edge has fully defined shape the source node in graph_in is
+  // replaced in graph_out by a dummy constant node. The mapping from nodes
+  // in graph_in to dummy nodes is stored in dummy_node_images.
+  //
+  // When a node in graph_in has at least one ancestor that doesn't have fully
+  // defined shape, it is copied into graph_out. The mapping from nodes in
+  // graph_in to copied nodes is stored in copied_node_images.
+  //
+  // The two types of node are treated differently because, when adding edges to
+  // graph_out, an output from a dummy node always uses port 0, whereas an
+  // output from a copied node uses the same port that was used in graph_in.
+  std::unordered_map<Node*, Node*> dummy_node_images;
+  std::unordered_map<Node*, Node*> copied_node_images;
+
+  std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
+  graph_out->set_versions(graph_in.versions());
+  static_shape_out->resize(send_node->num_inputs());
+
+  // We don't use the standard ReverseDFS because we want to cut off traversal
+  // whenever we find an output with fully defined shape.
+  // TODO(misard) make this work properly in the presence of control flow.
+  struct Work {
+    Node* node;
+    bool leave;  // Are we entering or leaving node?
+  };
+  std::vector<Work> stack({{send_node, false}});
+  std::vector<bool> visited(graph_in.num_node_ids(), false);
+  while (!stack.empty()) {
+    Work w = stack.back();
+    stack.pop_back();
+    Node* n = w.node;
+
+    if (w.leave) {
+      TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
+          n, send_node, dummy_node_images, library, &copied_node_images,
+          graph_out.get()));
+    } else {
+      if (visited[n->id()]) continue;
+      visited[n->id()] = true;
+
+      // Arrange to revisit when all done with all inputs.
+      stack.push_back(Work{n, true});
+
+      bool has_parent_with_unknown_shape = false;
+      for (const Edge* in_edge : n->in_edges()) {
+        if (!in_edge->IsControlEdge()) {
+          Node* src_node = in_edge->src();
+          int src_port = in_edge->src_output();
+          shape_inference::InferenceContext* context =
+              shape_refiner.GetContext(src_node);
+          shape_inference::ShapeHandle shape = context->output(src_port);
+          if (context->FullyDefined(shape)) {
+            // This ancestor has known shape, so instead of adding it to the
+            // stack, add a dummy node with that shape to graph_out and
+            // continue.
+            TensorShapeProto proto;
+            context->ShapeHandleToProto(shape, &proto);
+            dummy_node_images[src_node] = AddDummyShapedNode(
+                src_node->output_type(src_port), proto, graph_out.get());
+            if (n == send_node) {
+              (*static_shape_out)[in_edge->dst_input()] = proto;
+            }
+          } else {
+            if (!visited[src_node->id()]) {
+              has_parent_with_unknown_shape = true;
+              stack.push_back({src_node, false});
+            }
+          }
+        }
+      }
+      if (!has_parent_with_unknown_shape) {
+        if (n == send_node) {
+          // The shapes of all the inputs to send_node are statically known. We
+          // won't have to do any inference at compile time so return now: the
+          // shapes were stored in static_shape_out above.
+          graphdef_out->reset();
+          return Status::OK();
+        } else {
+          // Any shape that is being processed is either the original send node
+          // or has at least one output with statically-unknown shape. If the
+          // latter and it doesn't have any inputs with statically-unknown
+          // shape, then check that it is one of the recv nodes that we can fill
+          // in the shape of at run-time later. If it isn't one of those, then
+          // we won't have any additional knowledge at compile time, so we
+          // already know we won't be able to do shape inference and we can
+          // return an error now.
+          if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
+            return errors::InvalidArgument(
+                "Shape inference is not possible for outside_compilation "
+                "SendFromHost node ",
+                send_node->name(), " because shape of node ", n->name(),
+                " will not be known at compilation time.");
+          }
+        }
+      }
+    }
+  }
+
+  graphdef_out->reset(new GraphDef());
+  graph_out->ToGraphDef(graphdef_out->get());
+
+  return Status::OK();
+}
+
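The traversal in `DoStaticShapeInferenceForOutsideCompilationSend` above uses an explicit-stack DFS with two-phase (enter/leave) work items so that a node is copied only after all of its not-fully-defined ancestors have been processed. A standalone sketch of the same pattern on a plain adjacency list (illustrative, not TensorFlow API):

```cpp
#include <functional>
#include <vector>

// Illustrative: two-phase (enter/leave) explicit-stack DFS. A node's
// "leave" entry is pushed before its parents, so by the time it is popped
// every parent has already been entered and left.
struct WorkItem {
  int node;
  bool leave;  // Are we entering or leaving the node?
};

void VisitAncestorsPostOrder(int start,
                             const std::vector<std::vector<int>>& parents,
                             const std::function<void(int)>& on_leave) {
  std::vector<WorkItem> stack({{start, false}});
  std::vector<bool> visited(parents.size(), false);
  while (!stack.empty()) {
    WorkItem w = stack.back();
    stack.pop_back();
    if (w.leave) {
      on_leave(w.node);  // all ancestors of w.node already emitted
      continue;
    }
    if (visited[w.node]) continue;
    visited[w.node] = true;
    stack.push_back({w.node, true});  // revisit once the parents are done
    for (int p : parents[w.node]) {
      if (!visited[p]) stack.push_back({p, false});
    }
  }
}
```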
|
Status Encapsulator::MakePrunedGraphCopyAndInline(
|
||||||
|
const Graph& graph, const std::vector<Node*>& sink_nodes,
|
||||||
|
std::unique_ptr<Graph>* pruned_graph,
|
||||||
|
std::unordered_map<const Node*, Node*>* node_images,
|
||||||
|
FunctionLibraryDefinition* library) {
|
||||||
|
// First copy all ancestor nodes of sink_nodes into a new graph.
|
||||||
|
pruned_graph->reset(new Graph(library));
|
||||||
|
(*pruned_graph)->set_versions(graph.versions());
|
||||||
|
ReverseDFSFrom(graph, sink_nodes,
|
||||||
|
/*enter=*/nullptr,
|
||||||
|
/*leave=*/[&](Node* n) {
|
||||||
|
if (!n->IsSource()) {
|
||||||
|
Node* copied = (*pruned_graph)->CopyNode(n);
|
||||||
|
node_images->emplace(n, copied);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add all the edges between copied nodes.
|
||||||
|
for (auto entry : *node_images) {
|
||||||
|
const Node* orig = entry.first;
|
||||||
|
Node* image = entry.second;
|
||||||
|
for (const Edge* out_edge : orig->out_edges()) {
|
||||||
|
auto iter = node_images->find(out_edge->dst());
|
||||||
|
if (iter != node_images->end()) {
|
||||||
|
// The source and destination are both in the copied graph.
|
||||||
|
(*pruned_graph)
|
||||||
|
->AddEdge(image, out_edge->src_output(), iter->second,
|
||||||
|
out_edge->dst_input());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find all the function call nodes, and inline them.
|
||||||
|
std::vector<Node*> function_nodes;
|
||||||
|
for (auto node : (*pruned_graph)->nodes()) {
|
||||||
|
const OpRegistrationData* op_reg_data;
|
||||||
|
TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
|
||||||
|
if (op_reg_data->is_function_op) {
|
||||||
|
function_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto node : function_nodes) {
|
||||||
|
VLOG(2) << "Inlining function " << node->name();
|
||||||
|
const FunctionDef* fdef = library->Find(node->type_string());
|
||||||
|
if (fdef == nullptr) {
|
||||||
|
return errors::Internal("Failed to find function ", node->type_string(),
|
||||||
|
" in function library.");
|
||||||
|
}
|
||||||
|
FunctionBody* fbody = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
|
||||||
|
[library](const string& op, const OpDef** sig) {
|
||||||
|
return library->LookUpOpDef(op, sig);
|
||||||
|
},
|
||||||
|
&fbody));
|
||||||
|
InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
|
||||||
|
delete fbody;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
+Status Encapsulator::MakeGraphForOutsideCompilationSends(
+    const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
+    ShapeRefiner* shape_refiner,
+    std::unordered_map<const Node*, Node*>* node_images,
+    FunctionLibraryDefinition* library) {
+  // Find all the send_from_host nodes in all subgraphs, to use as roots for
+  // the pruning.
+  std::vector<Node*> send_from_host_nodes;
+  for (auto& subgraph_entry : subgraphs_) {
+    Subgraph& subgraph = subgraph_entry.second;
+    std::vector<string> outside_compilation_names;
+    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+    for (const auto& name : outside_compilation_names) {
+      Node* send_node = subgraph.GetSendFromHostNode(name);
+      if (send_node != nullptr) {
+        send_from_host_nodes.push_back(send_node);
+      }
+    }
+  }
+
+  // Make a copy of all the graph nodes needed to evaluate the send_from_host
+  // nodes, inlining any functions as needed.
+  TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
+      graph, send_from_host_nodes, pruned_graph, node_images, library));
+
+  // Perform shape inference on the pruned graph.
+  shape_refiner->set_require_shape_inference_fns(false);
+  FixupSourceAndSinkEdges(pruned_graph->get());
+  std::vector<Node*> post_order;
+  GetReversePostOrder(*(*pruned_graph), &post_order);
+  for (auto node : post_order) {
+    // Ignore the status returned by the shape_refiner. At this point we want
+    // the best effort shapes, even if no shape function is registered for a
+    // node.
+    Status status = shape_refiner->AddNode(node);
+    if (!status.ok()) {
+      VLOG(1) << "Shape inference failed for node: " << status;
+    }
+  }
+
+  return Status::OK();
+}
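
Note the design choice in the loop above: a failing AddNode is logged and skipped rather than propagated, so the pass collects whatever shapes it can. A hedged sketch of that best-effort pattern, with infer standing in for ShapeRefiner::AddNode (all names here are hypothetical):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    struct InferStatus {
      bool ok;
      std::string message;
    };

    // Applies a best-effort inference step in post order: failures are
    // reported but never abort the pass, mirroring the loop above.
    void BestEffortShapeInference(
        const std::vector<std::string>& post_order,
        const std::function<InferStatus(const std::string&)>& infer) {
      for (const auto& node : post_order) {
        InferStatus status = infer(node);
        if (!status.ok) {
          std::cerr << "Shape inference failed for " << node << ": "
                    << status.message << "\n";  // log and continue
        }
      }
    }
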
+Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
+    Graph* graph_out, FunctionLibraryDefinition* library) {
+  std::unique_ptr<Graph> pruned_graph;
+  ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
+  std::unordered_map<const Node*, Node*> node_images;
+  TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
+      *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
+
+  for (auto& subgraph_entry : subgraphs_) {
+    Subgraph& subgraph = subgraph_entry.second;
+    // Find all the recv_at_host nodes in this subgraph.
+    std::vector<string> outside_compilation_names;
+    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
+    std::unordered_set<string> recv_at_host_names;
+    for (const auto& name : outside_compilation_names) {
+      Node* recv_node = subgraph.GetRecvAtHostNode(name);
+      if (recv_node != nullptr) {
+        recv_at_host_names.insert(recv_node->name());
+      }
+    }
+    // For each send_from_host node, do as much shape inference as possible
+    // without knowing the shape of the recv_at_host nodes, and store the
+    // result, along with enough information to complete the job at compile
+    // time once the recv_at_host shapes are known.
+    for (const auto& name : outside_compilation_names) {
+      Node* send_node = subgraph.GetSendFromHostNode(name);
+      std::vector<TensorShapeProto> static_shape;
+      std::unique_ptr<GraphDef> graphdef;
+      if (send_node != nullptr) {
+        TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
+            *pruned_graph, shape_refiner, recv_at_host_names,
+            node_images[send_node], library, &static_shape, &graphdef));
+        if (graphdef == nullptr) {
+          VLOG(2) << "Send node " << send_node->name() << " shapes";
+          for (int i = 0; i < static_shape.size(); ++i) {
+            VLOG(2) << static_shape[i].DebugString();
+          }
+        } else {
+          VLOG(2) << "Send node " << send_node->name() << " graph\n"
+                  << graphdef->DebugString();
+        }
+      }
+      TF_RETURN_IF_ERROR(
+          subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
+    }
+    if (!outside_compilation_names.empty()) {
+      TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
+    }
+  }
+
+  return Status::OK();
+}
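
The per-send result above is an either/or: a vector of fully static shapes when inference succeeded, or a serialized inference graph to be re-run at compile time once the recv shapes are known. A small sketch of that contract using standard types; ShapeOrGraph and its members are made-up names for illustration, not part of the patch.

    #include <memory>
    #include <string>
    #include <vector>

    // Either fully-known shapes, or a deferred graph for compile time.
    struct ShapeOrGraph {
      std::vector<std::string> static_shapes;       // set when inference succeeded
      std::unique_ptr<std::string> deferred_graph;  // set when it must be re-run
      bool NeedsCompileTimeInference() const { return deferred_graph != nullptr; }
    };
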
+Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+                                      FunctionLibraryDefinition* library) {
   // Map from nodes in the input graph to nodes in the output graph.
   std::unordered_map<const Node*, Node*> node_images;
 
@@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
   TF_RETURN_IF_ERROR(
       AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
 
+  TF_RETURN_IF_ERROR(
+      GetShapeInfoForOutsideCompilationSends(graph_out, library));
+
   return Status::OK();
 }
 
@@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions(
   std::unique_ptr<Graph> out(new Graph(library));
   out->set_versions(graph_in.versions());
   TF_RETURN_IF_ERROR(
-      encapsulator.BuildOutputGraph(parallel_checking, out.get()));
+      encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
 
   *graph_out = std::move(out);
   return Status::OK();
@@ -29,17 +29,181 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
-                      string* diff) {
-  // TODO(phawkins) use a more sophisticated equality test.
-  if (a.DebugString() != b.DebugString()) {
-    if (diff) {
-      *diff = strings::StrCat("Definition mismatch for function ",
-                              a.signature().name(), ", expected:\n",
+template <class Tkey, class Tvalue>
+bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
+                   const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
+                   const std::function<string(const Tkey&)>& key_to_string,
+                   const std::function<string(const Tvalue&)>& value_to_string,
+                   const std::function<bool(const Tkey&, const Tvalue&,
+                                            const Tvalue&)>& compare,
+                   const string& map_name, string* diff) {
+  for (const auto& elt_a : a) {
+    const auto iter = b.find(elt_a.first);
+    if (iter == b.end()) {
+      if (diff) {
+        *diff = strings::StrCat(
+            map_name, " expected: contains element with key '",
+            key_to_string(elt_a.first), "' got: map has no such element");
+      }
+      return false;
+    }
+    if (!compare(elt_a.first, elt_a.second, iter->second)) {
+      if (diff) {
+        *diff = strings::StrCat(map_name, " expected: element with key '",
+                                key_to_string(elt_a.first), " has value '",
+                                value_to_string(elt_a.second), "' got: '",
+                                value_to_string(iter->second), "'");
+      }
+      return false;
+    }
+  }
+  for (const auto& elt_b : b) {
+    const auto iter = a.find(elt_b.first);
+    if (iter == a.end()) {
+      if (diff) {
+        *diff = strings::StrCat(map_name, " got: contains element with key '",
+                                key_to_string(elt_b.first),
+                                "' expected: map has no such element");
+      }
+      return false;
+    }
+  }
+  return true;
+}
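
Because EqualProtoMap takes its printers and comparator as callables, the callers that follow reuse it for both attr() maps (string to AttrValue) and ret() maps (string to string). A hedged usage sketch of the same contract over std::map, standing in for the protobuf map type:

    #include <map>
    #include <string>

    // Same calling contract as EqualProtoMap, over std::map for illustration.
    template <class K, class V, class KToS, class VToS, class Cmp>
    bool EqualMapSketch(const std::map<K, V>& a, const std::map<K, V>& b,
                        KToS key_to_string, VToS value_to_string, Cmp compare,
                        const std::string& map_name, std::string* diff) {
      for (const auto& elt : a) {
        auto it = b.find(elt.first);
        if (it == b.end()) {
          if (diff) {
            *diff = map_name + ": missing key '" + key_to_string(elt.first) + "'";
          }
          return false;
        }
        if (!compare(elt.first, elt.second, it->second)) {
          if (diff) {
            *diff = map_name + ": key '" + key_to_string(elt.first) +
                    "' expected '" + value_to_string(elt.second) + "' got '" +
                    value_to_string(it->second) + "'";
          }
          return false;
        }
      }
      for (const auto& elt : b) {
        if (a.find(elt.first) == a.end()) {
          if (diff) {
            *diff = map_name + ": unexpected key '" + key_to_string(elt.first) + "'";
          }
          return false;
        }
      }
      return true;
    }
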
+bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
+                          const string& diff_preamble, string* diff) {
+  if (a.op() != b.op()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected op '", a.op(), "' got '", b.op());
+    }
+    return false;
+  }
+  if (a.device() != b.device()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected device '", a.device(), "' got '",
+                              b.device());
+    }
+    return false;
+  }
+  if (a.input_size() != b.input_size()) {
+    if (diff) {
+      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                              ", expected ", a.input_size(), " inputs got ",
+                              b.input_size(), " expected:\n", a.DebugString(),
+                              "\ngot:\n", b.DebugString());
+    }
+    return false;
+  }
+  for (int i = 0; i < a.input_size(); ++i) {
+    if (a.input(i) != b.input(i)) {
+      if (diff) {
+        *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                                " input ", i, ", expected ", a.input(i),
+                                " got ", b.input(i), " expected:\n",
                                 a.DebugString(), "\ngot:\n", b.DebugString());
       }
       return false;
     }
+  }
+  return EqualProtoMap<string, AttrValue>(
+      a.attr(), b.attr(), [](const string& s) { return s; },
+      [](const AttrValue& v) { return v.DebugString(); },
+      [](const string& key, const AttrValue& av, const AttrValue& bv) {
+        if (key == "shape_inference_graph") {
+          // Default serialization of GraphDef is unstable because maps don't
+          // serialize deterministically. Rather than go through the hoops to
+          // turn on deterministic serialization of this attr just for this
+          // test, add logic here to compare deterministically.
+          GraphDef ga;
+          if (!ga.ParseFromString(av.s())) {
+            return false;
+          }
+          GraphDef gb;
+          if (!gb.ParseFromString(bv.s())) {
+            return false;
+          }
+          return EqualGraphDef(ga, gb, nullptr);
+        } else {
+          return av.DebugString() == bv.DebugString();
+        }
+      },
+      strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
+      diff);
+}
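
The shape_inference_graph special case above exists because protobuf map fields do not serialize deterministically, so two equivalent GraphDefs can yield different byte strings; the fix is to parse both sides and compare structurally. A self-contained sketch of that parse-then-compare pattern, where ToyGraph and ParseBlob are illustrative stand-ins for GraphDef and ParseFromString:

    #include <map>
    #include <string>

    struct ToyGraph {
      std::map<std::string, std::string> nodes;
    };

    // Toy parser for "name=value;" pairs; real code would use proto parsing.
    bool ParseBlob(const std::string& blob, ToyGraph* g) {
      size_t pos = 0;
      while (pos < blob.size()) {
        size_t eq = blob.find('=', pos);
        size_t semi = blob.find(';', pos);
        if (eq == std::string::npos || semi == std::string::npos || eq > semi)
          return false;
        g->nodes[blob.substr(pos, eq - pos)] = blob.substr(eq + 1, semi - eq - 1);
        pos = semi + 1;
      }
      return true;
    }

    // Compare parsed structure, not bytes: map equality is independent of the
    // order in which keys happened to be serialized.
    bool SerializedGraphsEqual(const std::string& a, const std::string& b) {
      ToyGraph ga, gb;
      if (!ParseBlob(a, &ga) || !ParseBlob(b, &gb)) return false;
      return ga.nodes == gb.nodes;
    }

For example, SerializedGraphsEqual("x=1;y=2;", "y=2;x=1;") is true even though the two strings differ byte for byte.
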
+bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
+                      string* diff) {
+  if (a.signature().DebugString() != b.signature().DebugString()) {
+    if (diff) {
+      *diff = strings::StrCat("Signature mismatch for function ",
+                              a.signature().name(), ", expected:\n",
+                              a.signature().DebugString(), "\ngot:\n",
+                              b.signature().DebugString());
+    }
+    return false;
+  }
+  if (!EqualProtoMap<string, AttrValue>(
+          a.attr(), b.attr(), [](const string& s) { return s; },
+          [](const AttrValue& v) { return v.DebugString(); },
+          [](const string& key, const AttrValue& av, const AttrValue& bv) {
+            return av.DebugString() == bv.DebugString();
+          },
+          strings::StrCat("attr mismatch for function ", a.signature().name()),
+          diff)) {
+    return false;
+  }
+  if (!EqualProtoMap<string, string>(
+          a.ret(), b.ret(), [](const string& s) { return s; },
+          [](const string& s) { return s; },
+          [](const string& key, const string& av, const string& bv) {
+            return av == bv;
+          },
+          strings::StrCat("ret mismatch for function ", a.signature().name()),
+          diff)) {
+    return false;
+  }
+  for (int i = 0; i < a.node_def_size(); ++i) {
+    bool found = false;
+    for (int j = 0; j < b.node_def_size(); ++j) {
+      if (a.node_def(i).name() == b.node_def(j).name()) {
+        if (!EqualFunctionNodeDef(
+                a.node_def(i), b.node_def(j),
+                strings::StrCat("Function ", a.signature().name()), diff)) {
+          return false;
+        }
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      if (diff) {
+        *diff = strings::StrCat("Function ", a.signature().name(),
+                                ", expected: has node '", a.node_def(i).name(),
+                                "' got: no node of that name");
+      }
+      return false;
+    }
+  }
+  for (int i = 0; i < b.node_def_size(); ++i) {
+    bool found = false;
+    for (int j = 0; j < a.node_def_size(); ++j) {
+      if (b.node_def(i).name() == a.node_def(j).name()) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      if (diff) {
+        *diff = strings::StrCat("Function ", a.signature().name(),
+                                ", got: has node '", b.node_def(i).name(),
+                                "' expected: no node of that name");
+      }
+      return false;
+    }
+  }
   return true;
 }
 
@@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
 
 // TODO(misard): remove these fake registrations once there are real Ops to be
 // compiled.
-REGISTER_OP("_XlaSendToHost")
-    .Input("input: dtypes")
-    .Attr("dtypes: list(type) >= 0");
-
-REGISTER_OP("_XlaRecvFromHost")
-    .Output("output: dtypes")
-    .Attr("dtypes: list(type) >= 0");
+REGISTER_OP("_XlaHostCompute")
+    .Input("inputs: Tinputs")
+    .Output("outputs: Toutputs")
+    .Attr("Tinputs: list(type) >= 0")
+    .Attr("Toutputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("_XlaSendFromHost")
-    .Input("input: dtypes")
-    .Attr("dtypes: list(type) >= 0");
+    .Input("input: Tinputs")
+    .Attr("Tinputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("_XlaRecvAtHost")
-    .Output("output: dtypes")
-    .Attr("dtypes: list(type) >= 0");
+    .Output("output: Toutputs")
+    .Attr("Toutputs: list(type) >= 0")
+    .Attr("key: string")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
-REGISTER_OP("InputTest").Output("o: float");
+REGISTER_OP("InputTest")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    });
 
-REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
+REGISTER_OP("InputTestShaped")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Vector(2));
+      return Status::OK();
+    });
+
+REGISTER_OP("UnaryTest")
+    .Input("a: float")
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      ::tensorflow::shape_inference::ShapeHandle o;
+      TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+      c->set_output(0, o);
+      return Status::OK();
+    });
+
 REGISTER_OP("BinaryTest")
     .Input("a: float")
     .Input("b: float")
-    .Output("o: float");
+    .Output("o: float")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      ::tensorflow::shape_inference::ShapeHandle o;
+      TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
+      c->set_output(0, o);
+      return Status::OK();
+    });
+
+REGISTER_OP("BinaryTest2")
+    .Input("a: float")
+    .Input("b: float")
+    .Output("o: float")
+    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
 
 REGISTER_OP("AddNLikeTest")
     .Input("inputs: N * T")
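
These registrations matter because the new shape-inference path keys off each op's shape function: UnknownShape forces the deferred inference-graph path, while a statically known shape such as Vector(2) lets the pass record concrete "shapes" instead. For illustration, one more registration in exactly the same pattern; the op name ExampleShaped is invented for this sketch and is not part of the change.

    REGISTER_OP("ExampleShaped")
        .Output("o: float")
        .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
          c->set_output(0, c->Vector(3));  // statically known length-3 vector
          return Status::OK();
        });
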
@@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) {
   return ops::SourceOp("InputTest", opts);
 }
 
+Node* InputShaped(const GraphDefBuilder::Options& opts) {
+  return ops::SourceOp("InputTestShaped", opts);
+}
+
+Node* KnownShape(const gtl::ArraySlice<int>& shape,
+                 const GraphDefBuilder::Options& opts) {
+  if (opts.HaveError()) return nullptr;
+  NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
+                           opts.op_registry());
+  TensorProto value;
+  value.set_dtype(DT_FLOAT);
+  for (int dim : shape) {
+    value.mutable_tensor_shape()->add_dim()->set_size(dim);
+  }
+  return opts.WithAttr("value", value)
+      .WithAttr("dtype", DT_FLOAT)
+      .FinalizeBuilder(&node_builder);
+}
+
-Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
+Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
                            "_XlaRecvAtHost", opts.op_registry());
-  return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
+  return opts.WithAttr("Toutputs", dtypes)
+      .WithAttr("key", key)
+      .FinalizeBuilder(&node_builder);
 }
 
-Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
-                   const gtl::ArraySlice<DataType>& dtypes,
+Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
                    const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
                            "_XlaSendFromHost", opts.op_registry());
   node_builder.Input(inputs);
-  return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
+  std::vector<DataType> dtypes;
+  for (const auto& node : inputs) {
+    dtypes.push_back(node.dt);
+  }
+  return opts.WithAttr("key", key)
+      .WithAttr("Tinputs", dtypes)
+      .FinalizeBuilder(&node_builder);
 }
 
 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
@@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b,
   return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
 }
 
+Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
+                         const GraphDefBuilder::Options& opts) {
+  return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
+}
+
 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
                const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
@@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
   *library_expected.add_function() = test::function::XTimesTwo();
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
@@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
           {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
           {{"F"},
            "BinaryTest",
-           {"C:o:0", "outside_compilation_O1_recv:output:0"},
+           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
            {},
-           {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
            {"C:o:0", "c:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"c"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
   Node* call = b2.opts().FinalizeBuilder(&node_builder);
 
   Node* recv =
-      RecvAtHost({DT_FLOAT, DT_FLOAT},
+      RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                  b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                    b2.opts().WithName("E").WithControlInputs({recv, b}));
-  Node* send = SendFromHost({e}, {DT_FLOAT},
+  Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
                             b2.opts()
                                 .WithName("outside_compilation_F1_O1_send")
                                 .WithControlInput(e));
@@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected_1;
+  {
+    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape1.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape1.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape1.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape1_graph;
+    TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
+    EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
+  }
+
+  string shape_string_expected_2;
+  {
+    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
+    Node* recv1 =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
+                     shape2.opts().WithName("E"));
+    Node* recv2 =
+        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
+                   shape2.opts().WithName("outside_compilation_F1_O2_recv"));
+    Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
+    SendFromHost("host_compute_channel_F1_O2", {h},
+                 shape2.opts().WithName("outside_compilation_F1_O2_send"));
+    GraphDef shape2_graph;
+    TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
+    EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
+  }
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
       {
           {{"C"}, "UnaryTest", {"a_0_arg"}},
           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
-          {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
+          {{"I"},
+           "UnaryTest",
+           {"outside_compilation_O2_host_compute:outputs:0"}},
           {{"F"},
            "BinaryTest",
-           {"C:o:0", "outside_compilation_O1_recv:output:0"},
+           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
            {},
-           {"outside_compilation_O1_recv"}},
-          {{"outside_compilation_O2_send"},
-           "_XlaSendToHost",
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O2_host_compute"},
+           "_XlaHostCompute",
           {"D:o:0", "F:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O2"},
+            {"shape_inference_graph", shape_string_expected_2},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
          {"F"}},
-          {{"outside_compilation_O1_send"},
-           "_XlaSendToHost",
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
          {"C:o:0", "D:o:0"},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected_1},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
          {"D"}},
-          {{"outside_compilation_O2_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O2_send"}},
-          {{"outside_compilation_O1_recv"},
-           "_XlaRecvFromHost",
-           {},
-           {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-           {"outside_compilation_O1_send"}},
      },
      {{"i_0_retval", "I:o:0"}});
@@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
   Node* call = b2.opts().FinalizeBuilder(&node_builder);
 
   Node* recv1 =
-      RecvAtHost({DT_FLOAT, DT_FLOAT},
+      RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                  b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                    b2.opts().WithName("E").WithControlInputs({recv1, b}));
-  Node* send1 = SendFromHost({e}, {DT_FLOAT},
+  Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                              b2.opts()
                                  .WithName("outside_compilation_F1_O1_send")
                                  .WithControlInput(e));
 
   Node* recv2 =
-      RecvAtHost({DT_FLOAT, DT_FLOAT},
+      RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                  b2.opts().WithName("outside_compilation_F1_O2_recv"));
   Node* g = Binary(e, ops::NodeOut(recv2, 1),
                    b2.opts().WithName("G").WithControlInputs({recv2, e}));
   Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
-  Node* send2 = SendFromHost(
-      {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
+  Node* send2 =
+      SendFromHost("host_compute_channel_F1_O2", {h},
+                   b2.opts().WithName("outside_compilation_F1_O2_send"));
 
   Node* s = NoOp(b2.opts()
                      .WithName("F1_sequencer")
@@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
-    Node* b = Input(b1.opts().WithName("B"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
+    Node* b = InputShaped(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
@@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
+                     shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"},
       {"f_0_retval:float", "d_0_retval:float"}, {},
@@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
-          {"C:o:0", "outside_compilation_O1_recv:output:0"},
+          {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
-          {"outside_compilation_O1_recv"}},
-         {{"outside_compilation_O1_send"},
-          "_XlaSendToHost",
+          {"outside_compilation_O1_host_compute"}},
+         {{"outside_compilation_O1_host_compute"},
+          "_XlaHostCompute",
           {"C:o:0", "D:o:0"},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
+          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+           {"key", "host_compute_channel_F1_O1"},
+           {"shape_inference_graph", shape_string_expected},
+           {"shapes", gtl::ArraySlice<DataType>({})}},
          {"D"}},
-         {{"outside_compilation_O1_recv"},
-          "_XlaRecvFromHost",
-          {},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-          {"outside_compilation_O1_send"}},
      },
      {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
@@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
      {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
      {{"I"},
       "BinaryTest",
-      {"f_0_arg", "outside_compilation_O1_recv:output:0"}},
-     {{"outside_compilation_O1_send"},
-      "_XlaSendToHost",
+      {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
+     {{"outside_compilation_O1_host_compute"},
+      "_XlaHostCompute",
       {"G:o:0"},
-      {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
-     {{"outside_compilation_O1_recv"},
-      "_XlaRecvFromHost",
-      {},
-      {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-      {"outside_compilation_O1_send"}},
+      {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+       {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+       {"key", "host_compute_channel_F2_O1"},
+       {"shape_inference_graph", ""},
+       {"shapes",
+        gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
     },
     {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
   std::unique_ptr<FunctionLibraryDefinition> lib_def(
       new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-  Node* a = Input(b2.opts().WithName("A"));
-  Node* b = Input(b2.opts().WithName("B"));
+  Node* a = InputShaped(b2.opts().WithName("A"));
+  Node* b = InputShaped(b2.opts().WithName("B"));
 
   Node* recv1 =
-      RecvAtHost({DT_FLOAT, DT_FLOAT},
+      RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                  b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                    b2.opts().WithName("E").WithControlInputs({recv1, b}));
-  Node* send1 = SendFromHost({e}, {DT_FLOAT},
+  Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                              b2.opts()
                                  .WithName("outside_compilation_F1_O1_send")
                                  .WithControlInput(e));
@@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
   Node* s1 = NoOp(
       b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
 
-  Node* recv2 = RecvAtHost(
-      {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
+  Node* recv2 =
+      RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
+                 b2.opts().WithName("outside_compilation_F2_O1_recv"));
   Node* h = Binary(ops::NodeOut(call1, 1), recv2,
                    b2.opts().WithName("H").WithControlInput(s1));
-  Node* send2 = SendFromHost(
-      {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
+  Node* send2 =
+      SendFromHost("host_compute_channel_F2_O1", {h},
+                   b2.opts().WithName("outside_compilation_F2_O1_send"));
 
   NodeBuilder node_builder2("F2", "F2", lib_def.get());
   node_builder2.Input(e).Input(call1);
@@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
     Node* b = Input(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
@@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
       {
@@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
-          {"D:o:0", "outside_compilation_O1_recv:output:0"}},
-         {{"outside_compilation_O1_recv"},
-          "_XlaRecvFromHost",
+          {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+         {{"outside_compilation_O1_host_compute"},
+          "_XlaHostCompute",
           {},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+          {{"Tinputs", gtl::ArraySlice<DataType>({})},
+           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+           {"key", "host_compute_channel_F1_O1"},
+           {"shape_inference_graph", ""},
+           {"shapes",
+            gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
   std::unique_ptr<FunctionLibraryDefinition> lib_def(
       new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-  Node* a = Input(b2.opts().WithName("A"));
+  Node* a = InputShaped(b2.opts().WithName("A"));
   Node* b = Input(b2.opts().WithName("B"));
 
   Node* e = Unary(a, b2.opts().WithName("E"));
-  Node* send1 = SendFromHost(
-      {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+  Node* send1 =
+      SendFromHost("host_compute_channel_F1_O1", {e},
+                   b2.opts().WithName("outside_compilation_F1_O1_send"));
   NodeBuilder node_builder1("F1", "F1", lib_def.get());
   node_builder1.Input(a).Input(b);
   Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
 
   {
     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
-    Node* a = Input(b1.opts().WithName("A"));
+    Node* a = InputShaped(b1.opts().WithName("A"));
     Node* b = Input(b1.opts().WithName("B"));
     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
     Node* d =
@@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
   FunctionDefLibrary library_expected;
   GraphDef graphdef_expected;
 
+  TensorShapeProto shape_proto_expected;
+  shape_proto_expected.add_dim()->set_size(2);
+
   *library_expected.add_function() = FunctionDefHelper::Create(
       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
       {
@@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
-          {"D:o:0", "outside_compilation_O1_recv:output:0"}},
-         {{"outside_compilation_O1_send"},
-          "_XlaSendToHost",
+          {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
+         {{"outside_compilation_O1_host_compute"},
+          "_XlaHostCompute",
           {},
-          {{"dtypes", gtl::ArraySlice<DataType>({})}},
+          {{"Tinputs", gtl::ArraySlice<DataType>({})},
+           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+           {"key", "host_compute_channel_F1_O1"},
+           {"shape_inference_graph", ""},
+           {"shapes",
+            gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
          {"D"}},
-         {{"outside_compilation_O1_recv"},
-          "_XlaRecvFromHost",
-          {},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
-          {"outside_compilation_O1_send"}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
   std::unique_ptr<FunctionLibraryDefinition> lib_def(
       new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
-  Node* a = Input(b2.opts().WithName("A"));
+  Node* a = InputShaped(b2.opts().WithName("A"));
   Node* b = Input(b2.opts().WithName("B"));
 
   Node* recv1 =
-      RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+      RecvAtHost("host_compute_channel_F1_O1", {},
+                 b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
-  Node* send1 = SendFromHost(
-      {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
+  Node* send1 =
+      SendFromHost("host_compute_channel_F1_O1", {e},
+                   b2.opts().WithName("outside_compilation_F1_O1_send"));
   NodeBuilder node_builder1("F1", "F1", lib_def.get());
   node_builder1.Input(a).Input(b);
   Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
-         {{"outside_compilation_O1_send"},
-          "_XlaSendToHost",
+         {{"outside_compilation_O1_host_compute"},
+          "_XlaHostCompute",
           {"D:o:0"},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
+          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+           {"Toutputs", gtl::ArraySlice<DataType>({})},
+           {"key", "host_compute_channel_F1_O1"},
+           {"shape_inference_graph", ""},
+           {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
   Node* a = Input(b2.opts().WithName("A"));
   Node* b = Input(b2.opts().WithName("B"));
 
-  Node* recv1 = RecvAtHost(
-      {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+  Node* recv1 =
+      RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                 b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Unary(recv1, b2.opts().WithName("E"));
   NodeBuilder node_builder1("F1", "F1", lib_def.get());
   node_builder1.Input(a).Input(b);
@@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
-         {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
-         {{"outside_compilation_O1_send"},
-          "_XlaSendToHost",
-          {"D:o:0"},
-          {{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
-         {{"outside_compilation_O1_recv"},
-          "_XlaRecvFromHost",
-          {},
-          {{"dtypes", gtl::ArraySlice<DataType>({})}},
-          {"outside_compilation_O1_send"}},
+         {{"F"},
+          "UnaryTest",
+          {"D:o:0"},
+          {},
+          {"outside_compilation_O1_host_compute"}},
+         {{"outside_compilation_O1_host_compute"},
+          "_XlaHostCompute",
+          {"D:o:0"},
+          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+           {"Toutputs", gtl::ArraySlice<DataType>({})},
+           {"key", "host_compute_channel_F1_O1"},
+           {"shape_inference_graph", ""},
+           {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
   Node* a = Input(b2.opts().WithName("A"));
   Node* b = Input(b2.opts().WithName("B"));
 
-  Node* recv1 = RecvAtHost(
-      {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
+  Node* recv1 =
+      RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                 b2.opts().WithName("outside_compilation_F1_O1_recv"));
   Node* e = Unary(recv1, b2.opts().WithName("E"));
-  Node* send1 = SendFromHost({}, {},
+  Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
                              b2.opts()
                                  .WithName("outside_compilation_F1_O1_send")
                                  .WithControlInput(e));
@@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
 }
 
+// Test for shape inference of outside compilation.
+TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
+  FunctionDefLibrary library;
+  GraphDef graphdef;
+
+  {
+    *library.add_function() = test::function::XTimesTwo();
+
+    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
+    Node* a = InputShaped(b1.opts().WithName("A"));
+    Node* b = Input(b1.opts().WithName("B"));
+    // Give nodes 'c' and 'd' names that collide after lowercasing.
+    Node* c = Unary(a, b1.opts().WithName("C"));
+    Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
+                           "_encapsulate", "F1"));
+    Node* e = BinaryUnknownShape(c, d,
+                                 b1.opts()
+                                     .WithName("E")
+                                     .WithControlInputs({b, d})
+                                     .WithAttr("_encapsulate", "F1")
+                                     .WithAttr("_outside", "O1"));
+    Node* f = Binary(c, e,
+                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
+                         "_encapsulate", "F1"));
+    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
+    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
+  }
+
+  TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+
+  FunctionDefLibrary library_expected;
+  GraphDef graphdef_expected;
+
+  string shape_string_expected;
+  {
+    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
+    Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
+    SendFromHost("host_compute_channel_F1_O1", {e},
+                 shape.opts().WithName("outside_compilation_F1_O1_send"));
+    GraphDef shape_graph;
+    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
+    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
+  }
+
+  *library_expected.add_function() = test::function::XTimesTwo();
+  *library_expected.add_function() = FunctionDefHelper::Create(
+      "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
+      {
+          {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
+          {{"F"},
+           "BinaryTest",
+           {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
+           {},
+           {"outside_compilation_O1_host_compute"}},
+          {{"outside_compilation_O1_host_compute"},
+           "_XlaHostCompute",
+           {"c:o:0"},
+           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+            {"key", "host_compute_channel_F1_O1"},
+            {"shape_inference_graph", shape_string_expected},
+            {"shapes", gtl::ArraySlice<DataType>({})}},
+           {"c"}},
+      },
+      {{"f_0_retval", "F:o:0"}});
+
+  {
+    std::unique_ptr<FunctionLibraryDefinition> lib_def(
+        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
+    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
+    Node* a = InputShaped(b2.opts().WithName("A"));
+    Node* b = Input(b2.opts().WithName("B"));
+    Node* c = Unary(a, b2.opts().WithName("C"));
+
+    NodeBuilder node_builder("F1", "F1", lib_def.get());
+    node_builder.Input(b).Input(c);
+    Node* call =
+        b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
+
+    Node* recv =
+        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
+                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
+    Node* e = BinaryUnknownShape(
+        c, ops::NodeOut(recv, 0),
+        b2.opts().WithName("E").WithControlInputs({recv, b}));
+    Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
+                              b2.opts()
+                                  .WithName("outside_compilation_F1_O1_send")
+                                  .WithControlInput(e));
+
+    Node* s = NoOp(
+        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
+
+    Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
+    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
+  }
+
+  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
+  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
+}
+
 }  // namespace
 }  // namespace tensorflow
@ -45,7 +45,7 @@ namespace tensorflow {
 // see comment on `AllowsAsynchronousDeallocation()`.
 class XlaAllocator : public xla::DeviceMemoryAllocator {
  public:
-  XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context);
+  XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
   ~XlaAllocator() override;
   xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
                                                 bool retry_on_failure) override;
@ -79,7 +79,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
   std::unordered_map<void*, Tensor> tensors_;
 };
 
-XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context)
+XlaAllocator::XlaAllocator(const gpu::Platform* platform,
+                           OpKernelContext* op_context)
     : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
 
 XlaAllocator::~XlaAllocator() = default;
@ -248,12 +249,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
+  // Builds an XLA allocator for the device.
+  XlaAllocator xla_allocator(client->platform(), ctx);
+
   XlaCompiler::Options options;
   options.client = client;
   options.device_type = &cache->device_type();
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
+  options.device_allocator = &xla_allocator;
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@ -264,9 +269,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  // Builds an XLA allocator for the device.
-  XlaAllocator xla_allocator(client->platform(), ctx);
-
   std::unique_ptr<xla::ShapedBuffer> output;
   // Build xla::ShapedBuffers that point directly to the Tensor buffers.
   std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
@ -374,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
     OP_REQUIRES(ctx,
                 write.input_index >= 0 && write.input_index < ctx->num_inputs(),
                 errors::Internal("Invalid input index for variable write."));
-    TensorShape write_shape;
-    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
 
     gpu::DeviceMemoryBase buffer = output->buffer({output_num});
 
@ -397,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
     // Looks up the owning Tensor by buffer address.
     OP_REQUIRES_OK(
-        ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
+        ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
                                                 variable->tensor()));
     ++output_num;
   }
@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args,
     XlaCompiler::Argument& arg = (*args)[input_num];
     arg.kind = XlaCompiler::Argument::kConstant;
     arg.type = input.dtype();
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+    arg.shape = input.shape();
     arg.constant_value = input;
     ++input_num;
   }
@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args,
       arg.constant_value = input;
     }
     arg.type = input.dtype();
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+    arg.shape = input.shape();
     ++input_num;
   }
 
@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args,
     if (variable_args[variable_id].present) {
      const Tensor& value = variable_args[variable_id].value;
       arg.type = value.dtype();
-      TF_RETURN_IF_ERROR(
-          TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
+      arg.shape = value.shape();
       arg.initialized = true;
     } else {
       // The values of uninitialized variables are not passed as inputs, since
@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args,
       // uninitialized variables.
       arg.initialized = false;
       arg.type = DT_INVALID;
-      arg.shape = xla::Shape();
+      arg.shape = TensorShape();
     }
     ++input_num;
   }
@ -223,6 +220,7 @@ Status XlaCompilationCache::BuildExecutable(
   xla::ExecutableBuildOptions build_options;
   build_options.set_device_ordinal(client_->default_device_ordinal());
   build_options.set_result_layout(result.xla_output_shape);
+  build_options.set_device_allocator(options.device_allocator);
 
   auto compile_result =
       client_->Compile(*result.computation, argument_layouts, build_options);
@ -144,6 +144,21 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "matrix_triangular_solve_op_test",
+    size = "small",
+    srcs = ["matrix_triangular_solve_op_test.py"],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+    ],
+)
+
 tf_xla_py_test(
     name = "clustering_test",
     size = "small",
@ -240,6 +255,18 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "extract_image_patches_op_test",
+    size = "small",
+    srcs = ["extract_image_patches_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "fft_test",
     size = "medium",
@ -326,6 +353,19 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "matrix_band_part_test",
+    size = "medium",
+    srcs = ["matrix_band_part_test.py"],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "momentum_test",
     size = "small",
@ -437,6 +477,18 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "reverse_sequence_op_test",
+    size = "small",
+    srcs = ["reverse_sequence_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "rmsprop_test",
     size = "small",
@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
         np.array([4, 5, 6], dtype=np.int32),
         expected=None)
 
+  def testMatrixSetDiag(self):
+    for dtype in self.numeric_types:
+      # Square
+      self._testBinary(
+          array_ops.matrix_set_diag,
+          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
+                   dtype=dtype),
+          np.array([1.0, 2.0, 3.0], dtype=dtype),
+          expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
+                            dtype=dtype))
+
+      self._testBinary(
+          array_ops.matrix_set_diag,
+          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
+                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
+                   dtype=dtype),
+          np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
+          expected=np.array(
+              [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
+               [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
+              dtype=dtype))
+
+      # Rectangular
+      self._testBinary(
+          array_ops.matrix_set_diag,
+          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
+          np.array([3.0, 4.0], dtype=dtype),
+          expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
+
+      self._testBinary(
+          array_ops.matrix_set_diag,
+          np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
+          np.array([3.0, 4.0], dtype=dtype),
+          expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
+
+      self._testBinary(
+          array_ops.matrix_set_diag,
+          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
+          np.array([[-1.0, -2.0], [-4.0, -5.0]],
+                   dtype=dtype),
+          expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+                             [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
+                            dtype=dtype))
+
 
 if __name__ == "__main__":
   googletest.main()
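Reviewer note on the testMatrixSetDiag cases above: matrix_set_diag overwrites the k = min(rows, cols) leading diagonal entries of the innermost matrix with the given vector. A minimal NumPy sketch of the single 2-D case, for reference only; the helper name is ours, not part of this change:

    import numpy as np

    def matrix_set_diag(mat, diag):
      # Overwrite the k leading diagonal entries, k = min(rows, cols).
      out = np.array(mat).copy()
      k = min(out.shape)
      out[np.arange(k), np.arange(k)] = diag
      return out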
new file: tensorflow/compiler/tests/extract_image_patches_op_test.py (134 lines)
@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for ExtractImagePatches op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExtractImagePatches(XLATestCase):
+  """Functional tests for ExtractImagePatches op."""
+
+  def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
+    """Tests input-output pairs for the ExtractImagePatches op.
+
+    Args:
+      image: Input tensor with shape: [batch, in_rows, in_cols, depth].
+      ksizes: Patch size specified as: [ksize_rows, ksize_cols].
+      strides: Output strides, specified as [stride_rows, stride_cols].
+      rates: Atrous rates, specified as [rate_rows, rate_cols].
+      padding: Padding type.
+      patches: Expected output.
+    """
+    ksizes = [1] + ksizes + [1]
+    strides = [1] + strides + [1]
+    rates = [1] + rates + [1]
+
+    with self.test_session():
+      image_placeholder = array_ops.placeholder(dtypes.float32)
+      with self.test_scope():
+        out_tensor = array_ops.extract_image_patches(
+            image_placeholder,
+            ksizes=ksizes,
+            strides=strides,
+            rates=rates,
+            padding=padding,
+            name="im2col")
+      feed_dict = {image_placeholder: image}
+      self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict))
+
+  def testKsize1x1Stride1x1Rate1x1(self):
+    """Verifies that for 1x1 kernel the output equals the input."""
+    # [2, 3, 4, 5]
+    image = np.reshape(range(120), [2, 3, 4, 5])
+    # [2, 3, 4, 5]
+    patches = np.reshape(range(120), [2, 3, 4, 5])
+    for padding in ["VALID", "SAME"]:
+      self._VerifyValues(
+          image,
+          ksizes=[1, 1],
+          strides=[1, 1],
+          rates=[1, 1],
+          padding=padding,
+          patches=patches)
+
+  def testKsize1x1Stride2x3Rate1x1(self):
+    """Test for 1x1 kernel and strides."""
+    # [2, 4, 5, 3]
+    image = np.reshape(range(120), [2, 4, 5, 3])
+    # [2, 2, 2, 3]
+    patches = image[:, ::2, ::3, :]
+    for padding in ["VALID", "SAME"]:
+      self._VerifyValues(
+          image,
+          ksizes=[1, 1],
+          strides=[2, 3],
+          rates=[1, 1],
+          padding=padding,
+          patches=patches)
+
+  def testKsize2x2Stride1x1Rate1x1Valid(self):
+    """Test for 2x2 kernel with VALID padding."""
+    # [1, 2, 2, 1]
+    image = [[[[1], [2]], [[3], [4]]]]
+    # [1, 1, 1, 4]
+    patches = [[[[1, 2, 3, 4]]]]
+    self._VerifyValues(
+        image,
+        ksizes=[2, 2],
+        strides=[1, 1],
+        rates=[1, 1],
+        padding="VALID",
+        patches=patches)
+
+  def testKsize2x2Stride1x1Rate1x1Same(self):
+    """Test for 2x2 kernel with SAME padding."""
+    # [1, 2, 2, 1]
+    image = [[[[1], [2]], [[3], [4]]]]
+    # [1, 2, 2, 4]
+    patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]]
+    self._VerifyValues(
+        image,
+        ksizes=[2, 2],
+        strides=[1, 1],
+        rates=[1, 1],
+        padding="SAME",
+        patches=patches)
+
+  def testKsize2x2Stride1x1Rate2x2Valid(self):
+    """Test for 2x2 kernel with 2x2 dilation."""
+    # [1, 4, 4, 1]
+    image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
+    # [1, 2, 2, 4]
+    patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
+                [[4, 6, 12, 14], [5, 7, 13, 15]]]]
+    self._VerifyValues(
+        image,
+        ksizes=[2, 2],
+        strides=[1, 1],
+        rates=[2, 2],
+        padding="VALID",
+        patches=patches)
+
+
+if __name__ == "__main__":
+  test.main()
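Reviewer note on the test file above: for readers unfamiliar with the op, here is a rough NumPy model of what ExtractImagePatches returns for the VALID-padding, rate-1 cases exercised here. It is illustrative only (SAME padding and dilation are ignored), and the function name is ours:

    import numpy as np

    def extract_patches_valid(image, ksize, stride):
      """image: [N, H, W, C]; returns [N, out_h, out_w, kh*kw*C]."""
      n, h, w, c = image.shape
      kh, kw = ksize
      sh, sw = stride
      out_h = (h - kh) // sh + 1
      out_w = (w - kw) // sw + 1
      out = np.zeros((n, out_h, out_w, kh * kw * c), image.dtype)
      for i in range(out_h):
        for j in range(out_w):
          # Flatten each kh x kw x C window in [row, col, depth] order.
          patch = image[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]
          out[:, i, j, :] = patch.reshape(n, -1)
      return out

On the 2x2 VALID test above, extract_patches_valid(np.array([[[[1], [2]], [[3], [4]]]]), (2, 2), (1, 1)) yields [[[[1, 2, 3, 4]]]], matching the expected patches.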
new file: tensorflow/compiler/tests/matrix_band_part_test.py (64 lines)
@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MatrixBandPartTest(XLATestCase):
+
+  def _testMatrixBandPart(self, dtype, shape):
+    with self.test_session():
+      batch_shape = shape[:-2]
+      mat = np.ones(shape).astype(dtype)
+      batch_mat = np.tile(mat, batch_shape + [1, 1])
+      for lower in -1, 0, 1, shape[-2] - 1:
+        for upper in -1, 0, 1, shape[-1] - 1:
+          band_np = mat
+          if lower >= 0:
+            band_np = np.triu(band_np, -lower)
+          if upper >= 0:
+            band_np = np.tril(band_np, upper)
+          if batch_shape:
+            band_np = np.tile(band_np, batch_shape + [1, 1])
+
+          placeholder = array_ops.placeholder(dtype)
+          with self.test_scope():
+            band = array_ops.matrix_band_part(
+                placeholder,
+                constant_op.constant(lower, dtype=dtypes.int32),
+                constant_op.constant(upper, dtype=dtypes.int32))
+          feed_dict = {placeholder: batch_mat}
+          self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
+
+  def testMatrixBandPart(self):
+    for dtype in self.float_types:
+      for batch_shape in [[], [2,], [1, 3, 2]]:
+        for rows in 1, 2, 7:
+          for cols in 1, 2, 7:
+            self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+
+
+if __name__ == "__main__":
+  test.main()
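Reviewer note: the expected values above are built exactly as matrix_band_part is defined. For a single 2-D matrix the semantics reduce to this NumPy sketch (a negative bound keeps the whole triangle on that side; helper name ours):

    import numpy as np

    def band_part(mat, lower, upper):
      # Keep entries within `lower` subdiagonals and `upper` superdiagonals.
      out = mat
      if lower >= 0:
        out = np.triu(out, -lower)
      if upper >= 0:
        out = np.tril(out, upper)
      return out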
new file: tensorflow/compiler/tests/matrix_triangular_solve_op_test.py (130 lines)
@ -0,0 +1,130 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.MatrixTriangularSolve."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+def MakePlaceholder(x):
+  return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
+
+
+class MatrixTriangularSolveOpTest(XLATestCase):
+
+  def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
+                                 placeholder_b, a, clean_a, b, verification,
+                                 atol):
+    feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
+    verification_np = sess.run(verification, feed_dict)
+    self.assertAllClose(b, verification_np, atol=atol)
+
+  def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
+    clean_a = np.tril(a) if lower else np.triu(a)
+    with self.test_session() as sess:
+      placeholder_a = MakePlaceholder(a)
+      placeholder_ca = MakePlaceholder(clean_a)
+      placeholder_b = MakePlaceholder(b)
+      with self.test_scope():
+        x = linalg_ops.matrix_triangular_solve(
+            placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
+      verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
+      self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
+                                      placeholder_b, a, clean_a, b,
+                                      verification, atol)
+
+  def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4):
+    transp = lambda x: np.swapaxes(x, -1, -2)
+    for lower, adjoint in itertools.product([True, False], repeat=2):
+      self._VerifyTriangularSolve(
+          a if lower else transp(a), b, lower, adjoint, atol)
+
+  def testBasic(self):
+    rng = np.random.RandomState(0)
+    a = np.tril(rng.randn(5, 5))
+    b = rng.randn(5, 7)
+    for dtype in self.float_types:
+      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
+
+  def testBasicNotActuallyTriangular(self):
+    rng = np.random.RandomState(0)
+    a = rng.randn(5, 5)  # the `a` matrix is not lower-triangular
+    b = rng.randn(5, 7)
+    for dtype in self.float_types:
+      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
+
+  def testBasicComplexDtypes(self):
+    rng = np.random.RandomState(0)
+    a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
+    b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
+    for dtype in self.complex_types:
+      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
+
+  def testBatch(self):
+    rng = np.random.RandomState(0)
+    shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
+              ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))]
+    tuples = itertools.product(self.float_types, shapes)
+    for dtype, (a_shape, b_shape) in tuples:
+      n = a_shape[-1]
+      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
+      b = rng.randn(*b_shape)
+      self._VerifyTriangularSolveCombo(
+          a.astype(dtype), b.astype(dtype), atol=1e-3)
+
+  def testLarge(self):
+    n = 1024
+    rng = np.random.RandomState(0)
+    a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
+    b = rng.randn(n, n)
+    self._VerifyTriangularSolve(
+        a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
+
+  def testNonSquareCoefficientMatrix(self):
+    rng = np.random.RandomState(0)
+    for dtype in self.float_types:
+      a = rng.randn(3, 4).astype(dtype)
+      b = rng.randn(4, 4).astype(dtype)
+      with self.assertRaises(ValueError):
+        linalg_ops.matrix_triangular_solve(a, b)
+      with self.assertRaises(ValueError):
+        linalg_ops.matrix_triangular_solve(a, b)
+
+  def testWrongDimensions(self):
+    randn = np.random.RandomState(0).randn
+    for dtype in self.float_types:
+      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
+      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
+      with self.assertRaises(ValueError):
+        linalg_ops.matrix_triangular_solve(lhs, rhs)
+      with self.assertRaises(ValueError):
+        linalg_ops.matrix_triangular_solve(lhs, rhs)
+
+
+if __name__ == "__main__":
+  test.main()
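Reviewer note: these tests validate the solver indirectly, by checking that multiplying the cleaned triangular matrix back against the solution reproduces b. For reference, the lower-triangular, non-adjoint case is plain forward substitution; a minimal NumPy sketch (ours, not part of the change):

    import numpy as np

    def forward_substitution(a, b):
      # Solves a @ x = b for lower-triangular a; a: [n, n], b: [n, k].
      n = a.shape[0]
      x = np.zeros_like(b)
      for i in range(n):
        x[i] = (b[i] - a[i, :i] @ x[:i]) / a[i, i]
      return x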
new file: tensorflow/compiler/tests/reverse_sequence_op_test.py (93 lines)
@ -0,0 +1,93 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.reverse_sequence_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ReverseSequenceTest(XLATestCase):
+
+  def _testReverseSequence(self,
+                           x,
+                           batch_axis,
+                           seq_axis,
+                           seq_lengths,
+                           truth,
+                           expected_err_re=None):
+    with self.test_session():
+      p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
+      lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
+      with self.test_scope():
+        ans = array_ops.reverse_sequence(
+            p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
+      if expected_err_re is None:
+        tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths})
+        self.assertAllClose(tf_ans, truth, atol=1e-10)
+      else:
+        with self.assertRaisesOpError(expected_err_re):
+          ans.eval(feed_dict={p: x, lengths: seq_lengths})
+
+  def testSimple(self):
+    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+    expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
+    self._testReverseSequence(
+        x,
+        batch_axis=0,
+        seq_axis=1,
+        seq_lengths=np.array([1, 3, 2], np.int32),
+        truth=expected)
+
+  def _testBasic(self, dtype, len_dtype):
+    x = np.asarray(
+        [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
+         [[17, 18, 19, 20], [21, 22, 23, 24]]],
+        dtype=dtype)
+    x = x.reshape(3, 2, 4, 1, 1)
+    x = x.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2
+
+    # reverse dim 2 up to (0:3, none, 0:4) along dim=0
+    seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)
+
+    truth_orig = np.asarray(
+        [
+            [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
+            [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
+            [[20, 19, 18, 17], [24, 23, 22, 21]]
+        ],  # reverse 0:4 (all)
+        dtype=dtype)
+    truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
+    truth = truth_orig.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2
+
+    seq_axis = 0  # permute seq_axis and batch_axis (originally 2 and 0, resp.)
+    batch_axis = 2
+    self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)
+
+  def testSeqLength(self):
+    for dtype in self.all_types:
+      for seq_dtype in self.int_types:
+        self._testBasic(dtype, seq_dtype)
+
+
+if __name__ == "__main__":
+  test.main()
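Reviewer note: reverse_sequence reverses only the first seq_lengths[i] elements of batch entry i along the sequence axis and leaves the rest in place. For the batch_axis=0, seq_axis=1 layout used in testSimple above, that is (sketch ours):

    import numpy as np

    def reverse_sequence(x, seq_lengths):
      # Reverse the first n elements of row i along axis 1.
      out = x.copy()
      for i, n in enumerate(seq_lengths):
        out[i, :n] = x[i, :n][::-1]
      return out

With x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] and seq_lengths = [1, 3, 2] this gives [[1, 2, 3], [6, 5, 4], [8, 7, 9]], matching the expected output above.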
@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase):
 
   def testFloatOps(self):
     for dtype in self.float_types:
+      x = np.arange(-0.90, 0.90, 0.25)
+      self._assertOpOutputMatchesExpected(
+          math_ops.acos,
+          x.astype(dtype),
+          expected=np.arccos(x).astype(dtype))
+      self._assertOpOutputMatchesExpected(
+          math_ops.asin,
+          x.astype(dtype),
+          expected=np.arcsin(x).astype(dtype))
+      x = np.arange(-3, 3).reshape(1, 3, 2)
+      self._assertOpOutputMatchesExpected(
+          math_ops.atan,
+          x.astype(dtype),
+          expected=np.arctan(x).astype(dtype))
+
       self._assertOpOutputMatchesExpected(
           math_ops.acosh,
           np.array([1, 2, 3, 4], dtype=dtype),
@ -427,16 +427,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
     // identity nodes are values used by the loop body or condition.
     // The Identity node may have the wrong device so copy the device from
     // one of its outputs instead.
+    std::deque<const Edge*> possible_exit;
     for (const Edge* edge : arg.switch_node->out_edges()) {
-      if (edge->src_output() == 0 && IsExit(edge->dst())) {
+      if (edge->src_output() == 0) {
+        possible_exit.push_back(edge);
+      }
+      if (IsIdentity(edge->dst())) {
+        TF_RETURN_IF_ERROR(
+            SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
+      }
+    }
+    // TODO(b/67425339): Allow general graph between switch and exit.
+    while (!possible_exit.empty()) {
+      const Edge* edge = possible_exit.front();
+      possible_exit.pop_front();
+      if (IsExit(edge->dst())) {
         if (arg.exit != nullptr) {
           return errors::InvalidArgument("Duplicate Exit successors to ",
                                          arg.switch_node->name());
         }
         arg.exit = edge->dst();
-      } else if (StringPiece(edge->dst()->type_string()) == "Identity") {
-        TF_RETURN_IF_ERROR(
-            SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
+      } else {
+        if (!IsIdentity(edge->dst())) {
+          return errors::Unimplemented("General graph between switch (",
+                                       arg.switch_node->name(),
+                                       ") and exit node of frame ",
+                                       frame->name, " not supported yet.");
+        }
+        for (const Edge* out : edge->dst()->out_edges()) {
+          possible_exit.push_back(out);
+        }
       }
     }
   }
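Reviewer note on the hunk above: the old code only recognized an Exit that was a direct successor of the Switch; the new code walks a work queue through chains of Identity nodes until it reaches the Exit, and rejects any other shape. A pseudo-Python sketch of that traversal (names ours, not from the change):

    from collections import deque

    def find_exits(switch_outputs, is_exit, is_identity, successors):
      # Breadth-first walk from the Switch's data outputs through Identity
      # chains; anything that is neither Identity nor Exit is unsupported.
      queue = deque(switch_outputs)
      exits = []
      while queue:
        node = queue.popleft()
        if is_exit(node):
          exits.append(node)
        elif is_identity(node):
          queue.extend(successors(node))
        else:
          raise NotImplementedError(
              "general graph between Switch and Exit not supported yet")
      return exits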
@ -6,6 +6,9 @@ Operator | Type Constraint
 `Acosh` | `T={complex64,double,float}`
 `Add` | `T={complex64,double,float,int32,int64}`
 `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
+`AdjustContrastv2` |
+`AdjustHue` |
+`AdjustSaturation` |
 `All` | `Tidx={int32,int64}`
 `Angle` | `Tout={double,float}`<br>`T={complex64}`
 `Any` | `Tidx={int32,int64}`
@ -34,7 +37,7 @@ Operator | Type Constraint
 `BroadcastGradientArgs` | `T={int32,int64}`
 `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Ceil` | `T={double,float}`
-`Cholesky` | `T={complex64,double,float}`
+`Cholesky` | `T={double,float}`
 `Complex` | `Tout={complex64}`<br>`T={double,float}`
 `ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
 `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -68,7 +71,11 @@ Operator | Type Constraint
 `Exp` | `T={complex64,double,float}`
 `ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Expm1` | `T={complex64,double,float}`
-`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
+`FFT` |
+`FFT2D` |
+`FFT3D` |
+`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Floor` | `T={double,float}`
 `FloorDiv` | `T={complex64,double,float,int32,int64}`
 `FloorMod` | `T={double,float,int32,int64}`
@ -80,6 +87,13 @@ Operator | Type Constraint
 `GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Greater` | `T={double,float,int32,int64,uint32,uint64}`
 `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
+`HSVToRGB` | `T={double,float}`
+`IFFT` |
+`IFFT2D` |
+`IFFT3D` |
+`IRFFT` |
+`IRFFT2D` |
+`IRFFT3D` |
 `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Imag` | `Tout={double,float}`<br>`T={complex64}`
@ -105,11 +119,14 @@ Operator | Type Constraint
 `MatMul` | `T={complex64,double,float}`
 `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`MatrixTriangularSolve` | `T={complex64,double,float}`
 `Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `MaxPool` | `T={double,float,int32,int64}`
 `MaxPool3D` | `T={float}`
 `MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
 `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolV2` | `T={double,float,int32,int64}`
 `Maximum` | `T={double,float,int32,int64}`
 `Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@ -131,6 +148,10 @@ Operator | Type Constraint
 `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `QuantizeAndDequantizeV2` | `T={double,float}`
+`RFFT` |
+`RFFT2D` |
+`RFFT3D` |
+`RGBToHSV` | `T={double,float}`
 `RandomStandardNormal` | `dtype={float}`
 `RandomUniform` | `T={int32,int64}`<br>`dtype={double,float}`
 `RandomUniformInt` | `T={int32,int64}`<br>`Tout={int32,int64}`
@ -146,6 +167,8 @@ Operator | Type Constraint
 `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
 `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
 `Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`ResizeBilinear` | `T={double,float,int32,int64}`
+`ResizeBilinearGrad` | `T={double,float}`
 `ResourceApplyAdagrad` | `T={double,float}`
 `ResourceApplyAdam` | `T={double,float}`
 `ResourceApplyFtrl` | `T={double,float}`
@ -156,6 +179,7 @@ Operator | Type Constraint
 `ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
 `ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Reverse` | `T={bool,complex64,double,float,int32,int64}`
+`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
 `RightShift` | `T={int32,int64,uint32,uint64}`
 `Rint` | `T={double,float}`
@ -6,6 +6,9 @@ Operator | Type Constraint
 `Acosh` | `T={complex64,double,float}`
 `Add` | `T={complex64,double,float,int32,int64}`
 `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
+`AdjustContrastv2` |
+`AdjustHue` |
+`AdjustSaturation` |
 `All` | `Tidx={int32,int64}`
 `Angle` | `Tout={double,float}`<br>`T={complex64}`
 `Any` | `Tidx={int32,int64}`
@ -34,7 +37,7 @@ Operator | Type Constraint
 `BroadcastGradientArgs` | `T={int32,int64}`
 `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Ceil` | `T={double,float}`
-`Cholesky` | `T={complex64,double,float}`
+`Cholesky` | `T={double,float}`
 `Complex` | `Tout={complex64}`<br>`T={double,float}`
 `ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
 `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -68,7 +71,11 @@ Operator | Type Constraint
 `Exp` | `T={complex64,double,float}`
 `ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Expm1` | `T={complex64,double,float}`
-`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
+`FFT` |
+`FFT2D` |
+`FFT3D` |
+`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Floor` | `T={double,float}`
 `FloorDiv` | `T={complex64,double,float,int32,int64}`
 `FloorMod` | `T={double,float,int32,int64}`
@ -80,6 +87,13 @@ Operator | Type Constraint
 `GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Greater` | `T={double,float,int32,int64,uint32,uint64}`
 `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
+`HSVToRGB` | `T={double,float}`
+`IFFT` |
+`IFFT2D` |
+`IFFT3D` |
+`IRFFT` |
+`IRFFT2D` |
+`IRFFT3D` |
 `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Imag` | `Tout={double,float}`<br>`T={complex64}`
@ -105,11 +119,14 @@ Operator | Type Constraint
 `MatMul` | `T={complex64,double,float}`
 `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`MatrixTriangularSolve` | `T={complex64,double,float}`
 `Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `MaxPool` | `T={double,float,int32,int64}`
 `MaxPool3D` | `T={float}`
 `MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
 `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
+`MaxPoolV2` | `T={double,float,int32,int64}`
 `Maximum` | `T={double,float,int32,int64}`
 `Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@ -131,6 +148,10 @@ Operator | Type Constraint
 `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
 `QuantizeAndDequantizeV2` | `T={double,float}`
+`RFFT` |
+`RFFT2D` |
+`RFFT3D` |
+`RGBToHSV` | `T={double,float}`
 `Range` | `Tidx={double,float,int32,int64}`
 `Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -143,6 +164,8 @@ Operator | Type Constraint
 `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
 `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
 `Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
+`ResizeBilinear` | `T={double,float,int32,int64}`
+`ResizeBilinearGrad` | `T={double,float}`
 `ResourceApplyAdagrad` | `T={double,float}`
 `ResourceApplyAdam` | `T={double,float}`
 `ResourceApplyFtrl` | `T={double,float}`
@ -153,6 +176,7 @@ Operator | Type Constraint
 `ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
 `ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `Reverse` | `T={bool,complex64,double,float,int32,int64}`
+`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
 `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
 `RightShift` | `T={int32,int64,uint32,uint64}`
 `Rint` | `T={double,float}`
@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
   for (int i = 0; i < args->size(); ++i) {
     XlaCompiler::Argument& arg = (*args)[i];
     arg.type = ctx->input_type(i);
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
-
+    arg.shape = ctx->InputShape(i);
 
     if (arg.type == DT_RESOURCE) {
       return errors::InvalidArgument(
@ -31,6 +31,7 @@ tf_kernel_library(
         "diag_op.cc",
         "dynamic_stitch_op.cc",
         "elu_op.cc",
+        "extract_image_patches_op.cc",
         "fft_ops.cc",
         "fill_op.cc",
         "function_ops.cc",
@ -43,6 +44,9 @@ tf_kernel_library(
         "l2loss_op.cc",
         "lrn_ops.cc",
         "matmul_op.cc",
+        "matrix_band_part_op.cc",
+        "matrix_set_diag_op.cc",
+        "matrix_triangular_solve_op.cc",
         "mirror_pad_op.cc",
         "no_op.cc",
         "one_hot_op.cc",
@ -58,6 +62,7 @@ tf_kernel_library(
         "reshape_op.cc",
         "retval_op.cc",
         "reverse_op.cc",
+        "reverse_sequence_op.cc",
         "scan_ops.cc",
         "segment_reduction_ops.cc",
         "select_op.cc",
@ -92,6 +97,7 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:batch_dot",
         "//tensorflow/compiler/tf2xla/lib:cholesky",
+        "//tensorflow/compiler/tf2xla/lib:triangular_solve",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
         "//tensorflow/compiler/xla:array4d",
@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result =
-        BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_);
+    auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1),
+                           /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
+                           /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
     OP_REQUIRES_OK(ctx, result.status());
     ctx->SetOutput(0, result.ValueOrDie());
   }
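Reviewer note on the BatchMatMul hunk above: passing adj_x_/adj_y_ for both the transpose and conjugate parameters makes the attributes mean conjugate-transpose rather than plain transpose, which only differs for complex inputs. In NumPy terms (our illustration, not from the change):

    import numpy as np

    def batch_matmul(x, y, adj_x=False, adj_y=False):
      # adjoint = conjugate transpose of the last two axes.
      if adj_x:
        x = np.conj(np.swapaxes(x, -1, -2))
      if adj_y:
        y = np.conj(np.swapaxes(y, -1, -2))
      return x @ y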
@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel {
   }
 };
 
-REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp);
+REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp);
 
 }  // namespace
 }  // namespace tensorflow
new file: tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc (169 lines)
@ -0,0 +1,169 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {

class ExtractImagePatchesOp : public XlaOpKernel {
 public:
  explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorFormat data_format = FORMAT_NHWC;
    const int num_dims = ksizes_.size();

    OP_REQUIRES(
        ctx, num_dims >= 3,
        errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
    const int num_spatial_dims = num_dims - 2;

    OP_REQUIRES(ctx, strides_.size() == num_dims,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims, " dimensions"));
    OP_REQUIRES(ctx, dilations_.size() == num_dims,
                errors::InvalidArgument("Dilations field must "
                                        "specify ",
                                        num_dims, " dimensions"));

    int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
    int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
    OP_REQUIRES(
        ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "kernel sizes > 1 in the batch and depth "
                              "dimensions."));
    OP_REQUIRES(
        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not support "
                              "dilations in the batch and depth dimensions."));

    for (int i = 0; i < num_spatial_dims; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      OP_REQUIRES(
          ctx, ksizes_[input_dim] >= 0,
          errors::Unimplemented("Kernel size values must be non-negative; ", i,
                                "th spatial dimension had dilation ",
                                dilations_[input_dim]));
      OP_REQUIRES(ctx, strides_[input_dim] >= 1,
                  errors::Unimplemented("Stride values must be positive; ", i,
                                        "th spatial dimension had dilation ",
                                        dilations_[input_dim]));
      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
                  errors::Unimplemented("Dilation values must be positive; ", i,
                                        "th spatial dimension had dilation ",
                                        dilations_[input_dim]));
    }

    xla::PrimitiveType type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == num_dims,
        errors::InvalidArgument("input must be ", num_dims, "-dimensional",
                                input_shape.DebugString()));
    const int64 depth = input_shape.dim_size(feature_dim);

    xla::ComputationBuilder* builder = ctx->builder();

    // The following code is equivalent to:
    // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
    int64 kernel_size = 1;
    std::vector<int64> lhs_shape(num_dims, 1);
    for (int i = 0; i < num_spatial_dims; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      lhs_shape[i] = ksizes_[input_dim];
      kernel_size *= ksizes_[input_dim];
    }
    lhs_shape[num_spatial_dims] = depth;
    lhs_shape[num_spatial_dims + 1] = 1;

    // Builds an identity matrix as a broadcast equality of iotas.
    // iota = np.arange(np.prod(ksize), depth)
    // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
    xla::ComputationDataHandle iota;
    TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
                                 kernel_size * depth, &iota));

    auto lhs = builder->Reshape(iota, lhs_shape);
    auto filter = builder->ConvertElementType(
        builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);

    xla::ConvolutionDimensionNumbers dims;
    std::vector<int64> window_strides(num_spatial_dims);
    std::vector<int64> lhs_dilation(num_spatial_dims, 1);
    std::vector<int64> rhs_dilation(num_spatial_dims);
    std::vector<std::pair<int64, int64>> padding(num_spatial_dims);

    dims.set_input_batch_dimension(batch_dim);
    dims.set_output_batch_dimension(batch_dim);
    dims.set_input_feature_dimension(feature_dim);
    dims.set_output_feature_dimension(feature_dim);
    dims.set_kernel_input_feature_dimension(num_spatial_dims);
    dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);

    for (int i = 0; i < num_spatial_dims; ++i) {
      const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      dims.add_input_spatial_dimensions(dim);
      dims.add_kernel_spatial_dimensions(i);
      dims.add_output_spatial_dimensions(dim);
      window_strides[i] = strides_.at(dim);
      rhs_dilation[i] = dilations_.at(dim);

      int64 unused_output_size;
      OP_REQUIRES_OK(
          ctx, GetWindowedOutputSizeVerboseV2(
                   input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
                   window_strides[i], padding_, &unused_output_size,
                   &padding[i].first, &padding[i].second));
    }

    xla::ComputationDataHandle conv =
        builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
                                    padding, lhs_dilation, rhs_dilation, dims);
    ctx->SetOutput(0, conv);
  }

 protected:
  std::vector<int32> ksizes_;
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
};

REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);

}  // namespace
}  // namespace tensorflow
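The kernel above lowers ExtractImagePatches to ConvGeneralDilated with a one-hot "identity" filter. As a standalone illustration — plain C++ with made-up patch sizes, not part of this change — the filter it materializes is np.eye(kH * kW * D) reshaped to [kH, kW, D, kH * kW * D], so the convolution copies each element of a window into its own output channel:

// Hypothetical sizes for illustration only.
#include <cstdio>
#include <vector>

int main() {
  const int kH = 2, kW = 2, D = 3;  // patch height/width, input depth
  const int K = kH * kW * D;        // one output channel per patch element
  std::vector<float> filter(static_cast<size_t>(kH) * kW * D * K, 0.0f);
  for (int h = 0; h < kH; ++h)
    for (int w = 0; w < kW; ++w)
      for (int d = 0; d < D; ++d) {
        const int o = (h * kW + w) * D + d;  // flattened patch index
        // filter[h][w][d][o] = 1: row o of np.eye(K), reshaped.
        filter[static_cast<size_t>(o) * K + o] = 1.0f;
      }
  int nonzeros = 0;
  for (float f : filter) nonzeros += (f != 0.0f) ? 1 : 0;
  std::printf("identity filter has %d ones (expected %d)\n", nonzeros, K);
  return 0;
}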
98 tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc Normal file
@@ -0,0 +1,98 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class MatrixBandPartOp : public XlaOpKernel {
 public:
  explicit MatrixBandPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const TensorShape num_lower_in_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
                errors::InvalidArgument("num_lower must be scalar, got shape ",
                                        num_lower_in_shape.DebugString()));

    const TensorShape num_upper_in_shape = context->InputShape(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
                errors::InvalidArgument("num_upper must be scalar, got shape ",
                                        num_upper_in_shape.DebugString()));

    xla::ComputationBuilder* builder = context->builder();
    xla::ComputationDataHandle input = context->Input(0);
    xla::ComputationDataHandle num_lower = context->Input(1);
    xla::ComputationDataHandle num_upper = context->Input(2);
    DataType input_type = context->input_type(0);
    DataType index_type = context->input_type(1);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);
    const int64 m = input_shape.dim_size(input_shape.dims() - 2);
    const int64 n = input_shape.dim_size(input_shape.dims() - 1);

    // Compute 'offset', which is how many diagonals we are above/below the
    // diagonal.
    xla::ComputationDataHandle iota_m;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));

    xla::ComputationDataHandle iota_n;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));

    auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
                               /*broadcast_dimensions=*/{0});

    // If num_lower or num_upper are negative, include all lower/upper
    // diagonals.
    auto zero_index = XlaHelpers::Zero(builder, index_type);
    num_lower = builder->Select(
        builder->Lt(num_lower, zero_index),
        XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
    num_upper = builder->Select(
        builder->Lt(num_upper, zero_index),
        XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);

    auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
                                  builder->Le(offset, num_upper));
    indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());

    auto zero_input = XlaHelpers::Zero(builder, input_type);
    auto output = builder->Select(
        indicator, input,
        builder->Broadcast(zero_input, input_shape.dim_sizes()));

    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
};
REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);

}  // namespace
}  // namespace tensorflow
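The indicator built above is a pure index computation: keep element (i, j) iff -num_lower <= j - i <= num_upper, with a negative bound meaning "unbounded on that side". A scalar sketch of the same rule (made-up 4x4 input, not TensorFlow code):

#include <cstdio>

int main() {
  const int m = 4, n = 4;
  double a[m][n];
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) a[i][j] = 1.0;

  int num_lower = 1;  // keep one subdiagonal
  int num_upper = 0;  // keep nothing above the main diagonal
  const int lo = num_lower < 0 ? m : num_lower;
  const int up = num_upper < 0 ? n : num_upper;

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      const int offset = j - i;  // which diagonal (i, j) sits on
      if (!(-lo <= offset && offset <= up)) a[i][j] = 0.0;
      std::printf("%.0f ", a[i][j]);
    }
    std::printf("\n");
  }
  return 0;
}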
93 tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc Normal file
@@ -0,0 +1,93 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);

    const int rank = input_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    // Check to make sure the last dimension of diag is equal to the smaller of
    // the last two dimensions of input.
    const int64 m = input_shape.dim_size(rank - 2);
    const int64 n = input_shape.dim_size(rank - 1);
    const int64 min_dim = std::min(m, n);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);

    TensorShape expected_diag_shape = batch_shape;
    expected_diag_shape.AddDim(min_dim);
    OP_REQUIRES(context, expected_diag_shape == diag_shape,
                errors::InvalidArgument(
                    "must have diagonal.shape == input.shape[:-2] + "
                    "min(input.shape[-2:]), but received input shape: ",
                    input_shape.DebugString(),
                    " and diagonal shape: ", diag_shape.DebugString()));

    xla::ComputationBuilder* builder = context->builder();
    xla::ComputationDataHandle input = context->Input(0);
    xla::ComputationDataHandle diag = context->Input(1);

    auto zero = XlaHelpers::Zero(builder, context->input_type(0));

    // Create an indicator tensor that is true only on the diagonal.
    xla::ComputationDataHandle iota_m;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
    xla::ComputationDataHandle iota_n;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
    auto indicator = builder->Eq(iota_m,
                                 builder->Broadcast(iota_n, {m}),
                                 /*broadcast_dimensions=*/{0});
    indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());

    // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
    // because we need to broadcast on the right.
    std::vector<int64> diag_broadcast_dims(rank - 1);
    std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
    if (min_dim != m) {
      diag_broadcast_dims.back() = rank - 1;
    }
    diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
                        /*broadcast_dimensions=*/diag_broadcast_dims);

    auto output = builder->Select(indicator, diag, input);
    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);

}  // namespace tensorflow
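Same pattern as MatrixBandPart: an iota-equality indicator marks the diagonal, and the new diagonal is selected in where the indicator is true. A scalar sketch of the select (made-up 3x3 values, not TensorFlow code):

#include <cstdio>

int main() {
  const int m = 3, n = 3;
  double input[m][n] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  double diag[3] = {-1, -2, -3};  // std::min(m, n) elements

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      const bool indicator = (i == j);  // builder->Eq(iota_m, iota_n)
      std::printf("%4.0f", indicator ? diag[i] : input[i][j]);
    }
    std::printf("\n");
  }
  return 0;
}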
50 tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc Normal file
@@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {
namespace {

class MatrixTriangularSolveOp : public XlaOpKernel {
 public:
  explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto result = TriangularSolve(
        ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true,
        /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
    if (!result.ok()) {
      ctx->SetStatus(result.status());
      return;
    }
    ctx->SetOutput(0, result.ValueOrDie());
  }

 private:
  bool lower_;
  bool adjoint_;
};

REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp);

}  // namespace
}  // namespace tensorflow
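For the lower/non-adjoint case, what TriangularSolve computes is ordinary forward substitution. A standalone scalar sketch with a made-up 3x3 system and one right-hand side (an illustration, not the library routine):

#include <cstdio>

int main() {
  const int n = 3;
  double L[n][n] = {{2, 0, 0}, {3, 1, 0}, {1, -1, 2}};  // lower triangular
  double b[n] = {2, 5, 4};
  double x[n];

  for (int i = 0; i < n; ++i) {  // solve L x = b row by row
    double s = b[i];
    for (int j = 0; j < i; ++j) s -= L[i][j] * x[j];
    x[i] = s / L[i][i];
  }
  for (int i = 0; i < n; ++i) std::printf("x[%d] = %g\n", i, x[i]);
  return 0;  // prints x = (1, 2, 2.5)
}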
@@ -37,6 +37,7 @@ class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
+    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
@@ -53,6 +54,7 @@ class PoolingOp : public XlaOpKernel {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
+    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
@@ -77,6 +79,33 @@ class PoolingOp : public XlaOpKernel {
    xla::ComputationDataHandle input = ctx->Input(0);
    const TensorShape input_shape = ctx->InputShape(0);

+    std::vector<int64> ksize = ksize_;
+    std::vector<int64> stride = stride_;
+    if (ctx->num_inputs() != 1) {
+      const TensorShape ksize_shape = ctx->InputShape(1);
+      // Validate input sizes.
+      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
+                  errors::InvalidArgument("ksize must be a vector, not shape ",
+                                          ksize_shape.DebugString()));
+      OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
+                  errors::InvalidArgument("Sliding window ksize field must "
+                                          "specify ",
+                                          num_dims(), " dimensions"));
+      ksize.clear();
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
+
+      const TensorShape stride_shape = ctx->InputShape(2);
+      // Validate input sizes.
+      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
+                  errors::InvalidArgument("stride must be a vector, not shape ",
+                                          stride_shape.DebugString()));
+      OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
+                  errors::InvalidArgument("Sliding window stride field must "
+                                          "specify ",
+                                          num_dims(), " dimensions"));
+      stride.clear();
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
+    }
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
@@ -84,8 +113,8 @@ class PoolingOp : public XlaOpKernel {

    const DataType type = input_type(0);
    xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
-        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_,
-        stride_, padding_);
+        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
+        stride, padding_);
    ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
  }

@@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp {
  }
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
+REGISTER_XLA_OP(Name("MaxPoolV2")
+                    .CompileTimeConstInput("ksize")
+                    .CompileTimeConstInput("strides"),
+                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
@@ -243,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
+    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
-      OP_REQUIRES(ctx, ksize_.size() == num_dims(),
-                  errors::InvalidArgument("Sliding window ksize field must "
-                                          "specify ",
-                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
-      OP_REQUIRES(ctx, stride_.size() == num_dims(),
-                  errors::InvalidArgument("Sliding window strides field must "
-                                          "specify ",
-                                          num_dims(), " dimensions"));
+    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
+    if (ctx->num_inputs() != 3) {
+      OP_REQUIRES(
+          ctx, ctx->num_inputs() == 5,
+          errors::InvalidArgument("Must supply ksize and stride arguments."));
+      const TensorShape ksize_shape = ctx->InputShape(3);
+      // Validate input sizes.
+      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
+                  errors::InvalidArgument("ksize must be a vector, not shape ",
+                                          ksize_shape.DebugString()));
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
+
+      const TensorShape stride_shape = ctx->InputShape(4);
+      // Validate input sizes.
+      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
+                  errors::InvalidArgument("stride must be a vector, not shape ",
+                                          stride_shape.DebugString()));
+      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
+    }
+
+    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
+                errors::InvalidArgument("Sliding window ksize field must "
+                                        "specify ",
+                                        num_dims(), " dimensions"));
+    OP_REQUIRES(ctx, stride_.size() == num_dims(),
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify ",
+                                        num_dims(), " dimensions"));
+
    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);
@@ -315,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp {
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
+REGISTER_XLA_OP(Name("MaxPoolGradV2")
+                    .CompileTimeConstInput("ksize")
+                    .CompileTimeConstInput("strides"),
+                MaxPool2DGradOp);

class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
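The new V2 variants take ksize/strides as compile-time-constant inputs instead of attributes; the window arithmetic itself is unchanged. A standalone sketch of the usual TensorFlow SAME/VALID output-size formula these kernels rely on (an illustration of the convention, not the library's GetWindowedOutputSizeVerboseV2 routine):

#include <algorithm>
#include <cstdio>

void WindowedOutputSize(int in, int k, int dilation, int stride, bool same,
                        int* out, int* pad_before, int* pad_after) {
  const int effective_k = (k - 1) * dilation + 1;
  if (same) {
    *out = (in + stride - 1) / stride;  // ceil(in / stride)
    const int pad = std::max(0, (*out - 1) * stride + effective_k - in);
    *pad_before = pad / 2;              // extra padding goes at the end
    *pad_after = pad - *pad_before;
  } else {  // VALID: no padding
    *out = (in - effective_k + stride) / stride;
    *pad_before = *pad_after = 0;
  }
}

int main() {
  int out, before, after;
  WindowedOutputSize(/*in=*/10, /*k=*/3, /*dilation=*/1, /*stride=*/2,
                     /*same=*/true, &out, &before, &after);
  std::printf("SAME:  out=%d pad=(%d,%d)\n", out, before, after);  // 5, (0,1)
  WindowedOutputSize(10, 3, 1, 2, false, &out, &before, &after);
  std::printf("VALID: out=%d\n", out);                             // 4
  return 0;
}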
182 tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc Normal file
@@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class ReverseSequenceOp : public XlaOpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape seq_lens_shape = context->InputShape(1);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens_shape.dims()));
    OP_REQUIRES(context, batch_dim_ != seq_dim_,
                errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_));
    OP_REQUIRES(
        context, seq_dim_ < input_shape.dims(),
        errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                seq_dim_, " vs. ", input_shape.dims(), ")"));
    OP_REQUIRES(
        context, batch_dim_ < input_shape.dims(),
        errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                batch_dim_, " vs. ", input_shape.dims(), ")"));
    OP_REQUIRES(
        context,
        seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_),
        errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_,
                                "), ", "(", seq_lens_shape.num_elements(),
                                " vs. ", input_shape.dim_size(batch_dim_)));

    xla::ComputationBuilder* builder = context->builder();
    const auto input = context->Input(0);
    const auto seq_lens = context->Input(1);

    const int64 batch_size = input_shape.dim_size(batch_dim_);

    const DataType input_type = context->input_type(0);
    const DataType seq_lens_type = context->input_type(1);
    const int64 max_seq_len = input_shape.dim_size(seq_dim_);

    xla::Shape input_xla_shape;
    OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
                                                  &input_xla_shape));
    xla::Shape seq_lens_xla_shape;
    OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
                                                  &seq_lens_xla_shape));

    const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
        xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
        seq_lens_xla_shape,
        input_xla_shape,
    });

    // For each entry in the batch, reverse the sequence.
    // TODO(b/65689298): generalize the Map() operator to non-scalar cases and
    // use it here, instead of a While loop.

    // Condition: lambda (i, _, _): i < batch_size
    auto condition_builder =
        builder->CreateSubBuilder("reverse_sequence_condition");
    {
      auto param = condition_builder->Parameter(0, tuple_shape, "param");
      auto i = condition_builder->GetTupleElement(param, 0);
      condition_builder->Lt(
          i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
                                        batch_size));
    }
    auto condition = condition_builder->Build();
    OP_REQUIRES_OK(context, condition.status());

    auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
    {
      auto param = body_builder->Parameter(0, tuple_shape, "param");
      auto i = body_builder->GetTupleElement(param, 0);
      auto seq_lens = body_builder->GetTupleElement(param, 1);
      auto output = body_builder->GetTupleElement(param, 2);

      // seq_len is the sequence length of the current batch element (rank 1)
      auto seq_len = body_builder->DynamicSlice(
          seq_lens, body_builder->Reshape(i, {1}), {1});

      // Indices is the offset of the batch element in the input.
      auto indices = body_builder->Broadcast(
          XlaHelpers::Zero(body_builder.get(), seq_lens_type),
          {input_shape.dims()});
      indices = body_builder->DynamicUpdateSlice(
          indices, body_builder->Reshape(i, {1}),
          body_builder->Reshape(
              XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                         batch_dim_),
              {1}));

      // slice_indices is the offset of the start of the reversed sequence in
      // the input.
      auto slice_indices = body_builder->DynamicUpdateSlice(
          indices,
          body_builder->Sub(XlaHelpers::IntegerLiteral(
                                body_builder.get(), seq_lens_type, max_seq_len),
                            seq_len),
          body_builder->Reshape(
              XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                         seq_dim_),
              {1}));

      // Slice out the reversed sequence. The slice will overflow the end of the
      // sequence, and the contents of the overflow are implementation-defined.
      // However, we will mask off these elements and replace them with elements
      // from the original input so their values do not matter.
      TensorShape slice_shape = input_shape;
      slice_shape.set_dim(batch_dim_, 1);
      auto slice = body_builder->DynamicSlice(output, slice_indices,
                                              slice_shape.dim_sizes());

      // Shift the reversed sequence to the left.
      output = body_builder->DynamicUpdateSlice(output, slice, indices);

      body_builder->Tuple(
          {body_builder->Add(
               i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
           seq_lens, output});
    }
    auto body = body_builder->Build();
    OP_REQUIRES_OK(context, body.status());

    auto loop_output = builder->While(
        condition.ValueOrDie(), body.ValueOrDie(),
        builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
                        builder->Rev(input, {seq_dim_})}));
    auto output = builder->GetTupleElement(loop_output, 2);

    // Mask out elements after the sequence length.
    xla::ComputationDataHandle iota;
    OP_REQUIRES_OK(
        context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
    std::vector<int64> dims(input_shape.dims(), 1);
    dims[batch_dim_] = batch_size;
    auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});

    // Broadcast the mask up to the input shape.
    mask =
        builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
                                             input_shape.dim_sizes()));

    output = builder->Select(mask, output, input);
    context->SetOutput(0, output);
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;
};

REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp);

}  // namespace
}  // namespace tensorflow
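The net effect of the while-loop-plus-mask construction above, written out in scalar form (made-up data, not TensorFlow code): within each batch row the first seq_lens[b] entries are reversed and everything past them passes through unchanged.

#include <cstdio>

int main() {
  const int batch = 2, max_seq_len = 5;
  int input[batch][max_seq_len] = {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}};
  int seq_lens[batch] = {3, 5};
  int output[batch][max_seq_len];

  for (int b = 0; b < batch; ++b)
    for (int t = 0; t < max_seq_len; ++t)
      output[b][t] = (t < seq_lens[b]) ? input[b][seq_lens[b] - 1 - t]
                                       : input[b][t];  // mask: keep original

  for (int b = 0; b < batch; ++b) {
    for (int t = 0; t < max_seq_len; ++t) std::printf("%d ", output[b][t]);
    std::printf("\n");  // prints: 3 2 1 4 5 / 10 9 8 7 6
  }
  return 0;
}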
@@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
    // Stack has not been initialized.
    xla::ComputationDataHandle zero =
        XlaHelpers::Zero(builder, resource->type());
-    TF_RETURN_IF_ERROR(resource->SetValue(
-        dtype,
-        builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
-                        builder->ConstantR0<int32>(0)})));
+    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Checks the expected shape matches the actual shape.
    TensorShape actual_shape;
@@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel {
    string name = strings::StrCat("Stack: ", stack_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
-                              value, &resource));
-    resource->set_tensor_array_size(size);
+                              TensorShape(), value, /*tensor_array_size=*/size,
+                              /*tensor_array_gradients=*/{}, &resource));
    ctx->SetResourceOutput(0, resource);
  }

@@ -164,10 +162,8 @@ class StackPushOp : public XlaOpKernel {

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
-    OP_REQUIRES_OK(
-        ctx,
-        resource->SetValue(
-            dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
+                            {b->DynamicUpdateSlice(ta, update, start_indices),
                             b->Add(index, b->ConstantR0<int32>(1))})));

    ctx->SetOutput(0, value);
@@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel {
    xla::ComputationDataHandle index = b->GetTupleElement(state, 1);

    index = b->Sub(index, b->ConstantR0<int32>(1));
-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    auto start_indices =
@@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
@@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

-    DataType lhs_type;
    TensorShape lhs_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+    xla::ComputationDataHandle lhs;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));

    const TensorShape rhs_shape = ctx->InputShape(4);

@@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel {
            " does not match r-value shape ", rhs_shape.DebugString(),
            ". Automatic broadcasting not yet implemented."));

-    xla::ComputationDataHandle lhs;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
-
    xla::ComputationDataHandle rhs = ctx->Input(4);

    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel {
          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
    }

-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
+  DataType dtype_;
};

REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
@@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
  TF_RET_CHECK(resource->tensor_array_size() >= 0)
      << resource->name() << " size " << resource->tensor_array_size();
-  TensorShape ta_shape;
-  ta_shape.AddDim(resource->tensor_array_size());
-  ta_shape.AppendShape(elem_shape);

  if (!resource->initialized()) {
    xla::ComputationDataHandle zero =
        XlaHelpers::Zero(builder, resource->type());
-    TF_RETURN_IF_ERROR(resource->SetValue(
-        dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
+    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Checks the elem_shape matches the TensorArray shape.
    auto shape_or_status = builder->GetShape(resource->value());
@@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
    TensorShape shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+    TensorShape ta_shape;
+    ta_shape.AddDim(resource->tensor_array_size());
+    ta_shape.AppendShape(elem_shape);
    if (ta_shape != shape) {
      return errors::InvalidArgument(
          "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
                           xla::ComputationBuilder* builder,
                           TensorShape* shape) {
-  TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
-  if (shape->dims() < 1) {
-    return errors::InvalidArgument("TensorArray rank must be >= 1");
-  }
+  *shape = resource->shape();
+  shape->InsertDim(0, resource->tensor_array_size());
  return Status::OK();
}

@@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
    // Initializes the TensorArray value if we know the element shape.
    // Otherwise, defer initialization to the first write.
    xla::ComputationDataHandle value;
-    if (element_shape_.IsFullyDefined()) {
    TensorShape shape;
+    if (element_shape_.IsFullyDefined()) {
      CHECK(element_shape_.AsTensorShape(&shape));
      TensorShape ta_shape;
      ta_shape.AddDim(size);
@@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
    string name = strings::StrCat("TensorArray: ", tensor_array_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
-                              dtype_, value, &var));
-    var->set_tensor_array_size(size);
+                              dtype_, shape, value, /*tensor_array_size=*/size,
+                              /*tensor_array_gradients=*/{}, &var));
    ctx->SetResourceOutput(0, var);

    Tensor flow(DT_FLOAT, TensorShape({}));
@@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
    xla::ComputationDataHandle written =
        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);

-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
+    OP_REQUIRES_OK(ctx, resource->SetValue(written));
    ctx->SetOutput(0, flow);
  }

@@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
      }
    }

-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
+    OP_REQUIRES_OK(ctx, resource->SetValue(ta));
    ctx->SetOutput(0, flow);
  }

@@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
                                        value_shape.DebugString(), " vs. ",
                                        ta_shape.DebugString()));

-    OP_REQUIRES_OK(
-        ctx, resource->SetValue(
-                 dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
+                            ta, b->Reshape(value, ta_shape.dim_sizes()))));

    ctx->SetOutput(0, flow);
  }
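After this refactor the backing shape is derived on demand rather than stored: in both MaybeInitializeTensorArray and GetTensorArrayShape it is always [tensor_array_size] + elem_shape. A trivial standalone sketch of that composition (plain C++, invented sizes, std::vector standing in for TensorShape):

#include <cstdio>
#include <vector>

int main() {
  const int tensor_array_size = 8;        // number of elements in the array
  std::vector<int> elem_shape = {4, 16};  // shape of each element
  std::vector<int> ta_shape = {tensor_array_size};
  ta_shape.insert(ta_shape.end(), elem_shape.begin(), elem_shape.end());
  for (int d : ta_shape) std::printf("%d ", d);  // prints: 8 4 16
  std::printf("\n");
  return 0;
}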
@@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle handle;
    xla::ComputationBuilder* b = ctx->builder();
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    DataType type = ctx->input_type(1);
+    TensorShape var_shape;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
+
+    TensorShape alpha_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+                errors::InvalidArgument("alpha is not a scalar: ",
+                                        alpha_shape.DebugString()));
+
+    TensorShape delta_shape = ctx->InputShape(2);
+    OP_REQUIRES(
+        ctx, var_shape.IsSameSize(delta_shape),
+        errors::InvalidArgument("var and delta do not have the same shape: ",
+                                var_shape.DebugString(), " vs ",
+                                delta_shape.DebugString()));
+
    handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
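The hunk above only adds shape validation; the update it compiles is the usual gradient-descent step, var <- var - alpha * delta. A scalar sketch with made-up values (not TensorFlow code), mirroring the new checks that alpha is a scalar and var/delta agree in shape:

#include <cstdio>

int main() {
  const int n = 3;
  double var[n] = {1.0, 2.0, 3.0};
  const double alpha = 0.1;  // scalar learning rate
  const double delta[n] = {0.5, -1.0, 0.0};
  for (int i = 0; i < n; ++i) var[i] -= alpha * delta[i];
  for (int i = 0; i < n; ++i) std::printf("%g ", var[i]);  // 0.95 2.1 3
  std::printf("\n");
  return 0;
}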
@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel {
|
|||||||
|
|
||||||
DataType type = ctx->input_type(2);
|
DataType type = ctx->input_type(2);
|
||||||
|
|
||||||
DataType var_type, accum_type;
|
|
||||||
TensorShape var_shape, accum_shape;
|
TensorShape var_shape, accum_shape;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
|
xla::ComputationDataHandle var, accum;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
|
||||||
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
|
||||||
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, type == var_type && type == accum_type,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Types of variable arguments to ResourceApplyMomentum must match: ",
|
|
||||||
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
|
|
||||||
DataTypeString(accum_type)));
|
|
||||||
|
|
||||||
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
|
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
|
|||||||
errors::InvalidArgument("momentum is not a scalar: ",
|
errors::InvalidArgument("momentum is not a scalar: ",
|
||||||
momentum_shape.DebugString()));
|
momentum_shape.DebugString()));
|
||||||
|
|
||||||
xla::ComputationDataHandle var, accum;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
|
|
||||||
|
|
||||||
xla::ComputationDataHandle lr = ctx->Input(2);
|
xla::ComputationDataHandle lr = ctx->Input(2);
|
||||||
xla::ComputationDataHandle grad = ctx->Input(3);
|
xla::ComputationDataHandle grad = ctx->Input(3);
|
||||||
xla::ComputationDataHandle momentum = ctx->Input(4);
|
xla::ComputationDataHandle momentum = ctx->Input(4);
|
||||||
@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
|||||||
|
|
||||||
DataType type = ctx->input_type(2);
|
DataType type = ctx->input_type(2);
|
||||||
|
|
||||||
DataType var_type, accum_type;
|
|
||||||
TensorShape var_shape, accum_shape;
|
TensorShape var_shape, accum_shape;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
|
xla::ComputationDataHandle var, accum;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
|
||||||
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
|
||||||
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, type == var_type && type == accum_type,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Types of variable arguments to ResourceApplyAdagrad must match: ",
|
|
||||||
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
|
|
||||||
DataTypeString(accum_type)));
|
|
||||||
|
|
||||||
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
|
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
|||||||
"var and grad do not have the same shape",
|
"var and grad do not have the same shape",
|
||||||
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
||||||
|
|
||||||
xla::ComputationDataHandle var, accum;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
|
|
||||||
xla::ComputationDataHandle lr = ctx->Input(2);
|
xla::ComputationDataHandle lr = ctx->Input(2);
|
||||||
xla::ComputationDataHandle grad = ctx->Input(3);
|
xla::ComputationDataHandle grad = ctx->Input(3);
|
||||||
|
|
||||||
@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
DataType var_type, m_type, v_type;
|
|
||||||
TensorShape var_shape, m_shape, v_shape;
|
TensorShape var_shape, m_shape, v_shape;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
|
xla::ComputationDataHandle var, m, v;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Types of variable arguments to ResourceApplyRMSProp must match: ",
|
|
||||||
DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
|
|
||||||
DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
|
|
||||||
|
|
||||||
TensorShape beta1_power_shape = ctx->InputShape(3);
|
TensorShape beta1_power_shape = ctx->InputShape(3);
|
||||||
TensorShape beta2_power_shape = ctx->InputShape(4);
|
TensorShape beta2_power_shape = ctx->InputShape(4);
|
||||||
@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel {
|
|||||||
"var and grad do not have the same shape",
|
"var and grad do not have the same shape",
|
||||||
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
||||||
|
|
||||||
xla::ComputationDataHandle var, m, v;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
|
|
||||||
xla::ComputationDataHandle beta1_power = ctx->Input(3);
|
xla::ComputationDataHandle beta1_power = ctx->Input(3);
|
||||||
xla::ComputationDataHandle beta2_power = ctx->Input(4);
|
xla::ComputationDataHandle beta2_power = ctx->Input(4);
|
||||||
xla::ComputationDataHandle lr = ctx->Input(5);
|
xla::ComputationDataHandle lr = ctx->Input(5);
|
||||||
@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
|
|||||||
|
|
||||||
DataType type = ctx->input_type(3);
|
DataType type = ctx->input_type(3);
|
||||||
|
|
||||||
DataType var_type, ms_type, mom_type;
|
|
||||||
TensorShape var_shape, ms_shape, mom_shape;
|
TensorShape var_shape, ms_shape, mom_shape;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
|
xla::ComputationDataHandle var, ms, mom;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, type == var_type && type == ms_type && type == mom_type,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Types of variable arguments to ResourceApplyRMSProp must match: ",
|
|
||||||
DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
|
|
||||||
DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
|
|
||||||
|
|
||||||
TensorShape lr_shape = ctx->InputShape(3);
|
TensorShape lr_shape = ctx->InputShape(3);
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
|
||||||
@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
|
|||||||
"var and grad do not have the same shape",
|
"var and grad do not have the same shape",
|
||||||
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
var_shape.DebugString(), " ", grad_shape.DebugString()));
|
||||||
|
|
||||||
xla::ComputationDataHandle var, ms, mom;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
|
|
||||||
xla::ComputationDataHandle lr = ctx->Input(3);
|
xla::ComputationDataHandle lr = ctx->Input(3);
|
||||||
xla::ComputationDataHandle rho = ctx->Input(4);
|
xla::ComputationDataHandle rho = ctx->Input(4);
|
||||||
xla::ComputationDataHandle momentum = ctx->Input(5);
|
xla::ComputationDataHandle momentum = ctx->Input(5);
|
||||||
@@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                  bool has_l2_shrinkage) {
   xla::ComputationBuilder* b = ctx->builder();
 
-  DataType var_type, accum_type, linear_type;
   TensorShape var_shape, accum_shape, linear_shape;
-  OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-  OP_REQUIRES_OK(ctx,
-                 ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-  OP_REQUIRES_OK(ctx,
-                 ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
-
-  OP_REQUIRES(
-      ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
-      errors::InvalidArgument(
-          "Types of variable arguments to ResourceApplyFtrlV2 must match: ",
-          DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
-          DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
+  xla::ComputationDataHandle var, accum, linear;
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
 
   OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
               errors::InvalidArgument(
@@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
               errors::InvalidArgument("lr_power is not a scalar: ",
                                       lr_power_shape.DebugString()));
 
-  xla::ComputationDataHandle var, accum, linear;
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
   xla::ComputationDataHandle grad = ctx->Input(3);
   xla::ComputationDataHandle lr = ctx->Input(4);
   xla::ComputationDataHandle l1 = ctx->Input(5);
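For orientation, the FTRL-Proximal update this function compiles follows the pseudocode in TensorFlow's ApplyFtrl documentation. A hedged plain-C++ sketch of the variant without the V2 l2_shrinkage term (FtrlStep is a hypothetical helper written from that documentation, not the compiled graph itself):

#include <cmath>
#include <vector>

// FTRL-Proximal step (no l2_shrinkage), per TF's ApplyFtrl pseudocode:
//   accum_new = accum + grad^2
//   linear  += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
//   quadratic = accum_new^(-lr_power) / lr + 2 * l2
//   var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0
void FtrlStep(std::vector<float>& var, std::vector<float>& accum,
              std::vector<float>& linear, const std::vector<float>& grad,
              float lr, float l1, float l2, float lr_power) {
  for (size_t i = 0; i < var.size(); ++i) {
    float accum_new = accum[i] + grad[i] * grad[i];
    linear[i] += grad[i] - (std::pow(accum_new, -lr_power) -
                            std::pow(accum[i], -lr_power)) /
                               lr * var[i];
    float quadratic = std::pow(accum_new, -lr_power) / lr + 2.0f * l2;
    var[i] = std::fabs(linear[i]) > l1
                 ? (std::copysign(l1, linear[i]) - linear[i]) / quadratic
                 : 0.0f;
    accum[i] = accum_new;
  }
}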
@@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
 // Return x if x>0, otherwise -x.
 XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
 
+// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
+XLAJIT_MAKE_UNARY(
+    Acos,
+    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+           b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
+                                  b->Mul(x, x)),
+                           XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
+                    b->Add(XlaHelpers::One(b, input_type(0)), x))));
+
 // acosh(x) = log(x + sqrt(x^2 - 1))
 XLAJIT_MAKE_UNARY(
     Acosh,
     b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x),
                                    XlaHelpers::One(b, input_type(0))),
                             XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+
+// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
+XLAJIT_MAKE_UNARY(
+    Asin,
+    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+           b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
+                              b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
+                                            b->Mul(x, x)),
+                                     XlaHelpers::FloatLiteral(b, input_type(0),
+                                                              0.5))))));
+
 // asinh(x) = log(x + sqrt(x^2 + 1))
 XLAJIT_MAKE_UNARY(
     Asinh,
     b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
                                    XlaHelpers::One(b, input_type(0))),
                             XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+
+XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));
+
 // atanh(x) = 0.5 * log((1 + x) / (1 - x))
 XLAJIT_MAKE_UNARY(
     Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
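The Acos and Asin lowerings above rest on the tangent half-angle identity; a short derivation for reference (standard trigonometry, not part of the patch):

\tan\frac{\theta}{2} = \frac{\sin\theta}{1 + \cos\theta}
\quad\Longrightarrow\quad
\arccos x = 2\arctan\frac{\sqrt{1 - x^2}}{1 + x},
\qquad
\arcsin x = 2\arctan\frac{x}{1 + \sqrt{1 - x^2}},

taking \theta = \arccos x (so \sin\theta = \sqrt{1 - x^2} and \cos\theta = x) for the first identity, and \theta = \arcsin x for the second. Expressing the result with Atan2 keeps the endpoint x = -1 well defined.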
@@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel {
  public:
   explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::ComputationDataHandle handle;
-    bool initialized = ctx->ReadVariableInput(0, &handle).ok();
-    ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
+    XlaResource* variable;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
+    ctx->SetOutput(0,
+                   ctx->builder()->ConstantR0<bool>(variable->initialized()));
   }
 };
 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
 
 class ReadVariableOp : public XlaOpKernel {
  public:
-  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(
+        ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
     ctx->SetOutput(0, handle);
   }
+
+ private:
+  DataType dtype_;
 };
 REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
@@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel {
  public:
   explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(1);
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     handle = ctx->builder()->Add(handle, ctx->Input(1));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
   }
 };
 REGISTER_XLA_OP(
@@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel {
  public:
   explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(1);
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     handle = ctx->builder()->Sub(handle, ctx->Input(1));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
   }
 };
 REGISTER_XLA_OP(
@@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::ComputationBuilder* builder = ctx->builder();
 
-    // Get the shape of the resource tensor.
+    DataType type = ctx->expected_output_dtype(0);
+
     TensorShape resource_shape;
-    DataType resource_dtype;
-    OP_REQUIRES_OK(
-        ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));
-
-    DataType expected_output_dtype = ctx->expected_output_dtype(0);
-    OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
-                errors::InvalidArgument(
-                    "Variable dtype is ", DataTypeString(resource_dtype),
-                    " but expected output dtype is ",
-                    DataTypeString(expected_output_dtype), "."));
-
     xla::ComputationDataHandle resource_handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
+                                               &resource_handle));
 
     auto indices = ctx->Input(1);
     auto indices_shape = ctx->InputShape(1);
     DataType index_type = ctx->input_type(1);
     xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
-        ctx, resource_handle, resource_shape, indices, indices_shape, 0,
-        resource_dtype, index_type, builder);
+        ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
+        index_type, builder);
     ctx->SetOutput(0, gather);
   }
 };
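Conceptually, the gather this kernel lowers to dynamic slices is plain row selection along axis 0. A minimal standard-library sketch of that semantics (GatherRows is a hypothetical helper for illustration, not the XLA lowering):

#include <cstdint>
#include <vector>

// Gathers rows of a [num_rows x row_len] row-major matrix:
// out[i, :] = data[indices[i], :], one contiguous slice per index,
// mirroring what the op computes with one dynamic slice per index.
std::vector<float> GatherRows(const std::vector<float>& data, int64_t row_len,
                              const std::vector<int64_t>& indices) {
  std::vector<float> out;
  out.reserve(indices.size() * row_len);
  for (int64_t idx : indices) {
    out.insert(out.end(), data.begin() + idx * row_len,
               data.begin() + (idx + 1) * row_len);
  }
  return out;
}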
@@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
       }
 
       arg.type = resource->type();
-      if (arg.initialized) {
-        TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
-      } else {
+      arg.shape = resource->shape();
+      if (!arg.initialized) {
         *has_uninitialized_vars = true;
       }
       arg.tensor_array_size = resource->tensor_array_size();
@@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs(
       arg.name = resource->name();
       VLOG(2) << "  resource " << resource->name()
               << " type: " << DataTypeString(arg.type)
-              << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
+              << " shape: " << arg.shape.DebugString()
              << " initialized: " << arg.initialized;
 
     } else {
       arg.kind = XlaCompiler::Argument::kParameter;
       arg.type = ctx->input_type(i);
-      TF_RETURN_IF_ERROR(
-          TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+      arg.shape = ctx->InputShape(i);
     }
   }
   return Status::OK();
@@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       XlaCompiler::Argument& arg = arguments[update.input_index];
       if (!arg.initialized) {
         VLOG(2) << "Update shape for argument " << update.input_index << " "
-                << xla::ShapeUtil::HumanString(update.shape);
+                << update.shape.DebugString();
         arg.initialized = true;
 
-        xla::Shape shape = update.shape;
-        if (!update.tensor_array_gradients_accessed.empty()) {
-          shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
-        }
-        std::unique_ptr<xla::Literal> zero =
-            xla::Literal::CreateFromShape(shape);
-        OP_REQUIRES_OK(ctx, resource->SetValue(
-                                update.type, builder->ConstantLiteral(*zero)));
+        arg.shape = update.shape;
+        OP_REQUIRES_OK(ctx,
+                       resource->SetTypeAndShape(update.type, update.shape));
+        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
       }
 
       // Add any TensorArray gradients touched by the body to the enclosing
@@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       for (const auto& gradient : resource->tensor_array_gradients()) {
         arg.tensor_array_gradients.insert(gradient.first);
       }
-
-      // Recompute the argument shape.
-      OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
     }
   }
   // Recompile the body with the "correct" resource shapes.
   VLOG(1) << "Recompiling body with corrected resource shapes";
@@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       OP_REQUIRES_OK(ctx,
                      resource->SetFromPack(
                          arguments[update.input_index].tensor_array_gradients,
-                         builder->GetTupleElement(while_result, pos),
-                         /*reset_initial_values=*/false, builder));
+                         builder->GetTupleElement(while_result, pos), builder));
     }
     VLOG(2) << "Loop-carried variable: pos: " << update.input_index
             << " name: " << resource->name() << " modified: " << update.modified
             << " type: " << DataTypeString(update.type)
-            << " shape: " << xla::ShapeUtil::HumanString(update.shape);
+            << " shape: " << update.shape.DebugString();
     // Copies the identity of the resource variable from input to output
     // unchanged, even if the variable was not modified.
     ctx->op_kernel_context()->set_output(
@@ -60,6 +60,8 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/core:lib",
@@ -25,11 +25,10 @@ limitations under the License.
 
 namespace tensorflow {
 
-// The current implementation simply unrolls the computation along the batch
-// dimension.
 xla::StatusOr<xla::ComputationDataHandle> BatchDot(
     xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
-    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) {
+    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
+    bool conjugate_x, bool conjugate_y) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
                       builder->GetShape(x));
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
@@ -89,10 +88,10 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
                         dimensions);
   }
 
-  if (x_shape->element_type() == xla::C64 && transpose_x) {
+  if (x_shape->element_type() == xla::C64 && conjugate_x) {
     x = builder->Conj(x);
   }
-  if (y_shape->element_type() == xla::C64 && transpose_y) {
+  if (y_shape->element_type() == xla::C64 && conjugate_y) {
     y = builder->Conj(y);
   }
@@ -27,7 +27,10 @@ namespace tensorflow {
 // viewed as an element of a batch), and arranges the individual results
 // in a single output tensor of the same batch size. Each of the
 // individual slices can optionally be transposed before multiplication by
-// setting the `transpose_x` or `transpose_y` flag to `true`.
+// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each
+// can be elementwise-complex-conjugated by setting the `conjugate_x` or
+// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both
+// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`.
 //
 // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
 // and `[..., r_y, c_y]`.
@@ -40,11 +43,10 @@ namespace tensorflow {
 // It is computed as:
 //
 //   output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-// TODO(phawkins): add an option to take the complex conjugate of the LHS or
-// RHS.
 xla::StatusOr<xla::ComputationDataHandle> BatchDot(
     xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
-    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y);
+    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
+    bool conjugate_x = false, bool conjugate_y = false);
 
 }  // namespace tensorflow
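To make the new flags concrete: conjugation and transposition compose into the Hermitian adjoint. A minimal standard-library sketch of the per-slice semantics (the Matrix alias and ApplyFlags helper are hypothetical illustrations; BatchDot itself operates on XLA handles, not std::vector):

#include <complex>
#include <vector>

using Matrix = std::vector<std::vector<std::complex<float>>>;

// Applies the transpose / conjugate flags to one matrix slice the way the
// header comment above describes: conj(transpose(m)) when both are set.
Matrix ApplyFlags(const Matrix& m, bool transpose, bool conjugate) {
  size_t rows = m.size(), cols = m[0].size();
  Matrix out(transpose ? cols : rows,
             std::vector<std::complex<float>>(transpose ? rows : cols));
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      std::complex<float> v = conjugate ? std::conj(m[i][j]) : m[i][j];
      if (transpose) {
        out[j][i] = v;  // Hermitian adjoint when conjugate is also true.
      } else {
        out[i][j] = v;
      }
    }
  }
  return out;
}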
@@ -71,11 +71,14 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
                         SliceInMinorDims(builder, l, {j + 1, 0}, {n, j}));
     TF_ASSIGN_OR_RETURN(auto r_squared,
                         BatchDot(builder, r, r, /*transpose_x=*/false,
-                                 /*transpose_y=*/true));
+                                 /*transpose_y=*/true, /*conjugate_x=*/false,
+                                 /*conjugate_y=*/false));
     new_d_squared = builder->Sub(new_d_squared, r_squared);
 
     TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false,
-                                     /*transpose_y=*/true));
+                                     /*transpose_y=*/true,
+                                     /*conjugate_x=*/false,
+                                     /*conjugate_y=*/false));
   }
   auto new_d_inv = builder->Pow(
       new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5));
@@ -134,7 +137,8 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
                         SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
     TF_ASSIGN_OR_RETURN(auto delta,
                         BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
-                                 /*transpose_y=*/true));
+                                 /*transpose_y=*/true, /*conjugate_x=*/false,
+                                 /*conjugate_y=*/false));
     TF_ASSIGN_OR_RETURN(auto before,
                         SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
     TF_ASSIGN_OR_RETURN(
@@ -155,6 +159,10 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
                         SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
     TF_ASSIGN_OR_RETURN(auto update,
                         TriangularSolve(builder, factorized, panel,
+                                        /*left_side=*/false,
+                                        /*lower=*/true,
+                                        /*transpose_a=*/true,
+                                        /*conjugate_a=*/false,
                                         /*block_size=*/8));
     TF_ASSIGN_OR_RETURN(
         l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));
@@ -29,6 +29,7 @@ namespace tensorflow {
 // the block size to use.
 // TODO(phawkins): check for negative values on the diagonal and return an
 // error, instead of silently yielding NaNs.
+// TODO(mattjj): handle the complex Hermitian case
 xla::StatusOr<xla::ComputationDataHandle> Cholesky(
     xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
     int64 block_size = 256);
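For orientation, the factorization these routines compute is the textbook Cholesky decomposition A = L * L^T. An unblocked reference version in plain C++ (CholeskyLower is a hypothetical helper; real case only, matching the TODO above about the complex Hermitian case being unhandled):

#include <cmath>
#include <vector>

// Unblocked Cholesky of a symmetric positive-definite n x n matrix stored
// row-major; returns the lower-triangular factor L with A = L * L^T.
std::vector<double> CholeskyLower(const std::vector<double>& a, int n) {
  std::vector<double> l(n * n, 0.0);
  for (int j = 0; j < n; ++j) {
    double d = a[j * n + j];
    for (int k = 0; k < j; ++k) d -= l[j * n + k] * l[j * n + k];
    l[j * n + j] = std::sqrt(d);  // NaN if a is not positive definite.
    for (int i = j + 1; i < n; ++i) {
      double s = a[i * n + j];
      for (int k = 0; k < j; ++k) s -= l[i * n + k] * l[j * n + k];
      l[i * n + j] = s / l[j * n + j];
    }
  }
  return l;
}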
@@ -24,13 +24,15 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
 xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    xla::ComputationDataHandle b, int64 block_size) {
+    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
+    bool conjugate_a, int64 block_size) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                       builder->GetShape(a));
   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
@@ -60,14 +62,15 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
     batch_dimensions.push_back(a_size);
   }
 
-  const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
-  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
-  if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
+  if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
+      xla::ShapeUtil::GetDimension(*a_shape, -2)) {
     return errors::InvalidArgument(
         "The 'a' arguments to TriangularSolve must be square matrices: ",
         xla::ShapeUtil::HumanString(*a_shape));
   }
-  if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) {
+  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
+  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
+  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
     return errors::InvalidArgument(
         "Arguments to TriangularSolve have incompatible matrix shapes: ",
         xla::ShapeUtil::HumanString(*a_shape), " vs ",
@@ -89,6 +92,14 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
     return output;
   };
 
+  // Applies a complex conjugation operation if `a` is complex and `conjugate_a`
+  // is true, otherwise returns its argument.
+  auto maybe_conj = [&](xla::ComputationBuilder* builder,
+                        xla::ComputationDataHandle x) {
+    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+    return perform_conj ? builder->Conj(x) : x;
+  };
+
   std::map<int, xla::Computation> base_computations;
   auto get_base_triangular_solve =
       [&](int k) -> xla::StatusOr<xla::Computation*> {
@@ -103,19 +114,35 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                            prepend_batch_dims({k, k})),
         "a");
 
+    std::array<int64, 2> b_lastd;
+    if (left_side) {
+      b_lastd = {k, n};
+    } else {
+      b_lastd = {m, k};
+    }
     auto b_param =
         sub->Parameter(1,
                        xla::ShapeUtil::MakeShape(b_shape->element_type(),
-                                                 prepend_batch_dims({m, k})),
+                                                 prepend_batch_dims(b_lastd)),
                        "b");
 
-    // TODO(phawkins): it might make sense to use a while loop here, rather
-    // than unrolling.
-    // TODO(phawkins): the left-looking variant of the algorithm might be more
-    // efficient at block size 1.
+    // We use a left-looking subroutine on the block diagonal in some common
+    // cases, while falling back to a recursive call in unsupported cases. The
+    // left-looking subroutine is written with a While loop and so yields much
+    // faster compile times. Moreover, the left-looking variant can give
+    // higher performance on smaller (sub)problems.
+    if (left_side && lower) {
+      TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
+                                                    b_param, transpose_a,
+                                                    conjugate_a)
+                             .status());
+    } else {
       TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
+                                         left_side, lower, transpose_a,
+                                         conjugate_a,
                                          /*block_size=*/1)
                              .status());
+    }
 
     TF_ASSIGN_OR_RETURN(computation, sub->Build());
   }
@@ -129,15 +156,18 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
   // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
   // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
   // (2008): 4.
+
+  // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
+  // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
+  // conjugate_a is True.
+
+  if (!left_side && lower == transpose_a) {
+    // for i in range(0, a.shape[-1], block_size):
     for (int64 i = 0; i < n; i += block_size) {
       int64 k = std::min(block_size, n - i);
 
-      // if k > 1:
       //   output[..., :, i:i+k] = triangular_solve(
-      //     a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right',
-      //     kind='Lower', transpose=True, block_size=1)
-      // else:
-      //   output[..., :, i] = b[..., :, i] / a[..., i, i]
+      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
       TF_ASSIGN_OR_RETURN(auto a_slice,
                           SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
       TF_ASSIGN_OR_RETURN(auto b_slice,
@@ -148,20 +178,31 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                             get_base_triangular_solve(k));
         update = builder->Call(*solve, {a_slice, b_slice});
       } else {
-        update = builder->Div(b_slice, a_slice);
+        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
       }
 
       TF_ASSIGN_OR_RETURN(
           output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-      // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k],
-      //                           np.transpose(..., a[i+k:, i:i+k]))
-      if (i + k < n) {
-        TF_ASSIGN_OR_RETURN(auto a_slice_2,
-                            SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
-        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2,
-                                                    /*transpose_x=*/false,
-                                                    /*transpose_y=*/true));
+
+      // if i + k < a.shape[-1]:
+      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+      //   b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
+      if (i + k < n) {
+        xla::ComputationDataHandle a_slice_2;
+        if (lower) {
+          TF_ASSIGN_OR_RETURN(
+              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
+        } else {
+          TF_ASSIGN_OR_RETURN(
+              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
+        }
+
+        TF_ASSIGN_OR_RETURN(auto b_update,
+                            BatchDot(builder, update, a_slice_2,
+                                     /*transpose_x=*/false,
+                                     /*transpose_y=*/transpose_a,
+                                     /*conjugate_x=*/false,
+                                     /*conjugate_y=*/conjugate_a));
         TF_ASSIGN_OR_RETURN(auto b_slice_2,
                             SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
         b_update = builder->Sub(b_slice_2, b_update);
@@ -169,7 +210,342 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
             b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
       }
     }
+  } else if (left_side && lower != transpose_a) {
+    // for i in range(0, a.shape[-1], block_size):
+    for (int64 i = 0; i < m; i += block_size) {
+      int64 k = std::min(block_size, m - i);
+
+      // output[..., i:i+k, :] = triangular_solve(
+      //   a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+      TF_ASSIGN_OR_RETURN(auto a_slice,
+                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
+      TF_ASSIGN_OR_RETURN(auto b_slice,
+                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
+      xla::ComputationDataHandle update;
+      if (k > 1) {
+        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+                            get_base_triangular_solve(k));
+        update = builder->Call(*solve, {a_slice, b_slice});
+      } else {
+        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
+      }
+      TF_ASSIGN_OR_RETURN(
+          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
+
+      // if i + k < a.shape[-1]:
+      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+      //   b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
+      if (i + k < m) {
+        xla::ComputationDataHandle a_slice_2;
+        if (lower) {
+          TF_ASSIGN_OR_RETURN(
+              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
+        } else {
+          TF_ASSIGN_OR_RETURN(
+              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
+        }
+
+        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
+                                                    /*transpose_x=*/transpose_a,
+                                                    /*transpose_y=*/false,
+                                                    /*conjugate_x=*/conjugate_a,
+                                                    /*conjugate_y=*/false));
+        TF_ASSIGN_OR_RETURN(auto b_slice_2,
+                            SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
+        b_update = builder->Sub(b_slice_2, b_update);
+        TF_ASSIGN_OR_RETURN(
+            b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
+      }
+    }
+  } else if (!left_side && lower != transpose_a) {
+    // for i in reversed(range(0, a.shape[-1], block_size)):
+    const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
+    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+      int64 k = std::min(block_size, n - i);
+
+      // output[..., :, i:i+k] = triangular_solve(
+      //   a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
+      TF_ASSIGN_OR_RETURN(auto a_slice,
+                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
+      TF_ASSIGN_OR_RETURN(auto b_slice,
+                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
+      xla::ComputationDataHandle update;
+      if (k > 1) {
+        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+                            get_base_triangular_solve(k));
+        update = builder->Call(*solve, {a_slice, b_slice});
+      } else {
+        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
+      }
+      TF_ASSIGN_OR_RETURN(
+          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
+
+      // if i - k >= 0:
+      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+      //   b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
+      if (i - k >= 0) {
+        xla::ComputationDataHandle a_slice_2;
+        if (lower) {
+          TF_ASSIGN_OR_RETURN(a_slice_2,
+                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+        } else {
+          TF_ASSIGN_OR_RETURN(a_slice_2,
+                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+        }
+
+        TF_ASSIGN_OR_RETURN(auto b_update,
+                            BatchDot(builder, update, a_slice_2,
+                                     /*transpose_x=*/false,
+                                     /*transpose_y=*/transpose_a,
+                                     /*conjugate_x=*/false,
+                                     /*conjugate_y=*/conjugate_a));
+        TF_ASSIGN_OR_RETURN(auto b_slice_2,
+                            SliceInMinorDims(builder, b, {0, 0}, {m, i}));
+        b_update = builder->Sub(b_slice_2, b_update);
+        TF_ASSIGN_OR_RETURN(
+            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+      }
+    }
+  } else {  // left_side && lower == transpose_a
+    // for i in reversed(range(0, a.shape[-1], block_size)):
+    const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
+    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+      int64 k = std::min(block_size, m - i);
+
+      // output[..., i:i+k, :] = triangular_solve(
+      //   a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+      TF_ASSIGN_OR_RETURN(auto a_slice,
+                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
+      TF_ASSIGN_OR_RETURN(auto b_slice,
+                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
+      xla::ComputationDataHandle update;
+      if (k > 1) {
+        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
+                            get_base_triangular_solve(k));
+        update = builder->Call(*solve, {a_slice, b_slice});
+      } else {
+        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
+      }
+      TF_ASSIGN_OR_RETURN(
+          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
+
+      // if i - k >= 0:
+      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+      //   b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
+      if (i - k >= 0) {
+        xla::ComputationDataHandle a_slice_2;
+        if (lower) {
+          TF_ASSIGN_OR_RETURN(a_slice_2,
+                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+        } else {
+          TF_ASSIGN_OR_RETURN(a_slice_2,
+                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+        }
+
+        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
+                                                    /*transpose_x=*/transpose_a,
+                                                    /*transpose_y=*/false,
+                                                    /*conjugate_x=*/conjugate_a,
+                                                    /*conjugate_y=*/false));
+        TF_ASSIGN_OR_RETURN(auto b_slice_2,
+                            SliceInMinorDims(builder, b, {0, 0}, {i, n}));
+        b_update = builder->Sub(b_slice_2, b_update);
+        TF_ASSIGN_OR_RETURN(
+            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+      }
+    }
+  }
+
   return output;
 }
 
+xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
+    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
+                      builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
+                      builder->GetShape(b));
+  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
+  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
+  const int64 ndims = xla::ShapeUtil::Rank(*a_shape);
+
+  std::vector<int64> batch_dimensions;
+  for (int i = 0; i < ndims - 2; ++i) {
+    int64 a_size = a_shape->dimensions(i);
+    batch_dimensions.push_back(a_size);
+  }
+
+  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
+    std::vector<int64> output(ndims);
+    std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
+    std::copy(indices.begin(), indices.end(),
+              output.begin() + batch_dimensions.size());
+    return output;
+  };
+
+  auto maybe_conj = [&](xla::ComputationBuilder* builder,
+                        xla::ComputationDataHandle x) {
+    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
+    return perform_conj ? builder->Conj(x) : x;
+  };
+
+  // The main computation is performed in a While loop.
+
+  // Allocate the output and set its first or last row,
+  // output = np.zeros_like(b)
+  // if transpose_a:
+  //   output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
+  // else:
+  //   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
+  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
+  {
+    auto i = transpose_a ? m - 1 : 0;
+    TF_ASSIGN_OR_RETURN(auto a_slice,
+                        SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
+    TF_ASSIGN_OR_RETURN(auto b_slice,
+                        SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
+    auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
+    TF_ASSIGN_OR_RETURN(
+        output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
+  }
+
+  // Construct the initial loop carry tuple,
+  // if transpose_a:
+  //   init = (m-2, output, a, b)
+  // else:
+  //   init = (1, output, a, b)
+  std::vector<xla::Shape> tuple_shapes = {
+      // The loop iteration counter is a scalar, incremented each iteration.
+      xla::ShapeUtil::MakeShape(xla::S32, {}),
+      // The output has the shape of b, with one row updated each iteration.
+      *b_shape,
+      // The coefficient matrix a is a loop invariant.
+      *a_shape,
+      // The right-hand-side matrix b is a loop invariant.
+      *b_shape};
+  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+  auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
+  auto init = builder->Tuple({init_i, output, a, b});
+
+  // Construct the loop condition function,
+  // def cond_fun(loop_carry):
+  //   i, output, a, b = loop_carry
+  //   return i >= 0 if transpose_a else i < m
+  std::unique_ptr<xla::ComputationBuilder> condb =
+      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
+  {
+    auto i = condb->GetTupleElement(
+        condb->Parameter(0, tuple_shape,
+                         "TriangularSolveLeftLookingWhileTuple"),
+        0);
+    if (transpose_a) {
+      condb->Ge(i, condb->ConstantR0<int32>(0));
+    } else {
+      condb->Lt(i, condb->ConstantR0<int32>(m));
+    }
+  }
+  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+  // Construct the loop body function,
+  // def body_fun(loop_carry):
+  //   i, output, a, b = loop_carry
+  //   if transpose_a:
+  //     a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2)
+  //   else:
+  //     a_row = a[..., i:i+1, :i]
+  //   result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
+  //   output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+  //   if transpose_a:
+  //     return (i - 1, output, a, b)
+  //   else:
+  //     return (i + 1, output, a, b)
+  // We have to do some extra FLOPs propagating zeros in the matrix multiply
+  // because we can't have the size of its arguments depend on the loop counter.
+  std::unique_ptr<xla::ComputationBuilder> bodyb =
+      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
+  {
+    auto input_tuple = bodyb->Parameter(0, tuple_shape,
+                                        "TriangularSolveLeftLookingWhileTuple");
+
+    // i, output, a, b = loop_carry
+    auto i = bodyb->GetTupleElement(input_tuple, 0);
+    auto body_out = bodyb->GetTupleElement(input_tuple, 1);
+    auto body_a = bodyb->GetTupleElement(input_tuple, 2);
+    auto body_b = bodyb->GetTupleElement(input_tuple, 3);
+    auto zero = bodyb->ConstantR0<int32>(0);
+
+    // Set up some helper functions.
+    auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
+      auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
+      std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
+      padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
+      padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
+      return bodyb->ConcatInDim(padded_starts, 0);
+    };
+
+    auto dynamic_slice = [&](xla::ComputationDataHandle x,
+                             std::array<xla::ComputationDataHandle, 2> starts,
+                             std::array<int64, 2> sizes) {
+      auto padded_starts = prepend_zeros(starts);
+      auto padded_sizes = prepend_batch_dims(sizes);
+      return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
+    };
+
+    auto update = [&](xla::ComputationDataHandle x,
+                      xla::ComputationDataHandle update,
+                      std::array<xla::ComputationDataHandle, 2> starts) {
+      auto padded_starts = prepend_zeros(starts);
+      return bodyb->DynamicUpdateSlice(x, update, padded_starts);
+    };
+
+    // We'd like to implement this:
+    // if transpose_a:
+    //   a_row = T(a[..., i+1:, i:i+1])
+    //   result_row = (b[..., i:i+1, :]
+    //                 - np.matmul(a_row, body_out[..., i+1:, :]))
+    // else:
+    //   result_row = (b[..., i:i+1, :]
+    //                 - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
+    // But since we can't have intermediate array sizes depend on the loop
+    // counter, we instead exploit the fact that we initialized the output to
+    // all zeros and use that as zero-padding (doing unnecessary FLOPs).
+    xla::ComputationDataHandle a_row;
+    if (transpose_a) {
+      a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
+    } else {
+      a_row = dynamic_slice(body_a, {i, zero}, {1, m});
+    }
+    TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
+                                                /*transpose_x=*/transpose_a,
+                                                /*transpose_y=*/false,
+                                                /*conjugate_x=*/conjugate_a,
+                                                /*conjugate_y=*/false));
+    auto result_row =
+        bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);
+
+    // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+    auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
+    auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
+    body_out = update(body_out, div_result, {i, zero});
+
+    // if transpose_a:
+    //   return (i - 1, body_out, a, b)
+    // else:
+    //   return (i + 1, body_out, a, b)
+    auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
+    bodyb->Tuple({next_i, body_out, body_a, body_b});
+  }
+  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+  // Construct the While loop and return the result,
+  // return while_loop(cond_fun, body_fun, init)[1]
+  auto triangular_solve_left_looking_while = builder->While(cond, body, init);
+  return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
+}
+
 }  // namespace tensorflow
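The While-loop body above mirrors ordinary row-by-row substitution. A plain-C++ reference version of the non-transposed case, without the zero-padding trick, for comparison (ForwardSubstitution is a hypothetical helper, not part of the patch):

#include <vector>

// Left-looking forward substitution: solves L * x = b for x one row at a
// time, where L is an m x m lower-triangular matrix (row-major) and b has
// n columns. Row i depends only on previously computed rows, which is what
// lets the XLA version carry `output` through a While loop.
std::vector<double> ForwardSubstitution(const std::vector<double>& l, int m,
                                        const std::vector<double>& b, int n) {
  std::vector<double> x(m * n, 0.0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      double s = b[i * n + j];
      for (int k = 0; k < i; ++k) s -= l[i * m + k] * x[k * n + j];
      x[i * n + j] = s / l[i * m + i];
    }
  }
  return x;
}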
@@ -21,25 +21,50 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Solves systems of linear equations with upper or lower triangular matrices by
-// backsubstitution.
+// Solves systems of linear equations with lower or upper triangular coefficient
+// matrices by forward- or back-substitution. Broadcasting along leading
+// dimensions, this routine solves one of the matrix systems
+//   `op(a) * x = b`, or `x * op(a) = b`,
+// for the variable `x` given `a` and `b`, where `op(a)` is either
+//   `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
+// That is, the innermost matrices in the output satisfy a scalar system
+// depending on the value of (left_side, transpose_a, conjugate_a) according
+// to:
+//   (F, F, F) => `output[..., i, k]  a[..., k, j] = b[..., i, j]`,
+//   (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`,
+//   (F, T, F) => `output[..., i, k]  a[..., j, k] = b[..., i, j]`,
+//   (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`,
+//   (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`,
+//   (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`,
+//   (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`,
+//   (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`,
+// where * denotes complex conjugation and where the index `k` is summed over.
 //
-// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
-// square matrices. The strictly upper triangular part of each inner-most matrix
-// is assumed to be zero and not accessed.
-// `b` is a tensor of shape `[..., M, K]`.
-//
-// The innermost matrices in the output satisfy matrix equations
-// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`.
+// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
+// square matrices. If lower is true (false), then the strictly upper (lower)
+// triangular part of each innermost matrix in `a` is assumed to be zero and is
+// not accessed.
+// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a
+// tensor of shape `[..., K, M]`.
+// `left_side` is a boolean, indicating whether to solve a system of the form
+// op(a) * x = b (true) or x * op(a) = b (false).
+// `lower` is a boolean, indicating whether the argument `a` is lower-triangular
+// (true) or upper-triangular (false).
+// `transpose_a` is a boolean indicating whether the matrix `a` is transposed.
+// `conjugate_a` is a boolean indicating whether the entries of `a` are complex
+// conjugated (independently of whether they are transposed), so that when both
+// transpose_a and conjugate_a are true the effect is a Hermitian adjoint.
 //
 // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
 // blocking is used.
-// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right,
-// kind=lower, and transposed_a=true. Implement the other possible combinations
-// of side, kind and transposed_a.
 xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
-    xla::ComputationDataHandle b, int64 block_size = 256);
+    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
+    bool conjugate_a, int64 block_size = 256);
+
+xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
+    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);
 
 }  // namespace tensorflow
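One identity worth noting when reading the right-side branches of the implementation (a direct consequence of the systems enumerated above, not stated in the header itself): a right-side solve reduces to a left-side solve on transposed data, since

x \, \mathrm{op}(a) = b
\quad\Longleftrightarrow\quad
\mathrm{op}(a)^{T}\, x^{T} = b^{T}.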
@@ -27,32 +27,68 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace tensorflow {
 namespace {
 
 using TriangularSolveTest = xla::ClientLibraryTestBase;
+using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase;
+using complex64 = xla::complex64;
 
-XLA_TEST_F(TriangularSolveTest, Simple) {
+xla::Array2D<float> AValsLower() {
+  return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}};
+}
+
+xla::Array2D<float> AValsUpper() {
+  return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}};
+}
+
+xla::Array2D<float> BValsRight() {
+  return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
+}
+
+xla::Array2D<float> BValsLeft() {
+  return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
+}
+
+xla::Array2D<complex64> AValsLowerComplex() {
+  return {{2, 0, 0, 0},
+          {complex64(3, 1), 6, 0, 0},
+          {4, complex64(7, 2), 9, 0},
+          {5, 8, complex64(10, 3), 11}};
+}
+
+xla::Array2D<complex64> AValsUpperComplex() {
+  return {{2, 3, complex64(4, 3), 5},
+          {0, 6, complex64(7, 2), 8},
+          {0, 0, complex64(9, 1), 10},
+          {0, 0, 0, 11}};
+}
+
+xla::Array2D<complex64> BValsRightComplex() {
+  return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
+}
+
+xla::Array2D<complex64> BValsLeftComplex() {
+  return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
+}
+
+xla::Array2D<float> AValsFull() {
+  return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
   xla::ComputationBuilder builder(client_, TestName());
 
-  xla::Array2D<float> a_vals({
-      {2, 0, 0, 0},
-      {3, 6, 0, 0},
-      {4, 7, 9, 0},
-      {5, 8, 10, 11},
-  });
-  xla::Array2D<float> b_vals({
-      {1, 2, 3, 4},
-      {5, 6, 7, 8},
-      {9, 10, 11, 12},
-  });
-
   xla::ComputationDataHandle a, b;
-  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  auto b_data = CreateR2Parameter<float>(b_vals, 1, "b", &builder, &b);
-  auto result = TriangularSolve(&builder, a, b, /*block_size=*/2);
+  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
+  auto result = TriangularSolve(&builder, a, b,
+                                /*left_side=*/false, /*lower=*/true,
+                                /*transpose_a=*/true, /*conjugate_a=*/false,
+                                /*block_size=*/2);
   TF_ASSERT_OK(result.status());
 
   xla::Array2D<float> expected({
@@ -62,7 +98,267 @@ XLA_TEST_F(TriangularSolveTest, Simple) {
   });
 
   ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
-                             xla::ErrorSpec(2e-3, 2e-3));
+                             xla::ErrorSpec(1e-2, 1e-2));
 }
 
+XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b;
+  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
+  auto result = TriangularSolve(&builder, a, b,
+                                /*left_side=*/false, /*lower=*/true,
+                                /*transpose_a=*/false, /*conjugate_a=*/false,
+                                /*block_size=*/2);
+  TF_ASSERT_OK(result.status());
+
+  xla::Array2D<float> expected({
+      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
+      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
+      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
+  });
+
+  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b;
+  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
+  auto result = TriangularSolve(&builder, a, b,
+                                /*left_side=*/false, /*lower=*/false,
+                                /*transpose_a=*/true, /*conjugate_a=*/false,
+                                /*block_size=*/2);
+  TF_ASSERT_OK(result.status());
+
+  xla::Array2D<float> expected({
+      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
+      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
+      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
+  });
+
+  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b;
+  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
+  auto result = TriangularSolve(&builder, a, b,
+                                /*left_side=*/false, /*lower=*/false,
+                                /*transpose_a=*/false, /*conjugate_a=*/false,
+                                /*block_size=*/2);
+  TF_ASSERT_OK(result.status());
+
+  xla::Array2D<float> expected({
+      {0.5, 0.08333334, 0.04629629, 0.03367003},
+      {2.5, -0.25, -0.1388889, -0.1010101},
+      {4.5, -0.58333331, -0.32407406, -0.23569024},
+  });
+
+  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b;
+  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
+  auto result = TriangularSolve(&builder, a, b,
+                                /*left_side=*/true, /*lower=*/true,
+                                /*transpose_a=*/true, /*conjugate_a=*/false,
+                                /*block_size=*/2);
+  TF_ASSERT_OK(result.status());
+
+  xla::Array2D<float> expected({
+      {-0.89646465, -0.69444444, -0.49242424},
+      {-0.27441077, -0.24074074, -0.20707071},
+      {-0.23232323, -0.22222222, -0.21212121},
+      {0.90909091, 1., 1.09090909},
+  });
+
+  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             xla::ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
+  xla::ComputationBuilder builder(client_, TestName());
+
+  xla::ComputationDataHandle a, b;
+  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
+  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolve(&builder, a, b,
|
||||||
|
/*left_side=*/true, /*lower=*/true,
|
||||||
|
/*transpose_a=*/false, /*conjugate_a=*/false,
|
||||||
|
/*block_size=*/2);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{0.5, 1.0, 1.5},
|
||||||
|
{0.41666667, 0.33333333, 0.25},
|
||||||
|
{0.23148148, 0.18518519, 0.13888889},
|
||||||
|
{0.16835017, 0.13468013, 0.1010101},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
|
||||||
|
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolve(&builder, a, b,
|
||||||
|
/*left_side=*/true, /*lower=*/false,
|
||||||
|
/*transpose_a=*/true, /*conjugate_a=*/false,
|
||||||
|
/*block_size=*/2);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{0.5, 1.0, 1.5},
|
||||||
|
{0.41666667, 0.33333333, 0.25},
|
||||||
|
{0.23148148, 0.18518519, 0.13888889},
|
||||||
|
{0.16835017, 0.13468013, 0.1010101},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
|
||||||
|
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolve(&builder, a, b,
|
||||||
|
/*left_side=*/true, /*lower=*/false,
|
||||||
|
/*transpose_a=*/false, /*conjugate_a=*/false,
|
||||||
|
/*block_size=*/2);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{-0.89646465, -0.69444444, -0.49242424},
|
||||||
|
{-0.27441077, -0.24074074, -0.20707071},
|
||||||
|
{-0.23232323, -0.22222222, -0.21212121},
|
||||||
|
{0.90909091, 1., 1.09090909},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data =
|
||||||
|
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
|
||||||
|
auto b_data =
|
||||||
|
CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolve(&builder, a, b,
|
||||||
|
/*left_side=*/false, /*lower=*/true,
|
||||||
|
/*transpose_a=*/true, /*conjugate_a=*/true,
|
||||||
|
/*block_size=*/2);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<complex64> expected({
|
||||||
|
{0.5, complex64(0.08333333, 0.08333333),
|
||||||
|
complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
|
||||||
|
{2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
|
||||||
|
complex64(0.08670034, -0.02104377)},
|
||||||
|
{4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
|
||||||
|
complex64(0.11026936, -0.03114478)},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<complex64>(&builder, expected,
|
||||||
|
{a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data =
|
||||||
|
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
|
||||||
|
auto b_data =
|
||||||
|
CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolve(&builder, a, b,
|
||||||
|
/*left_side=*/true, /*lower=*/false,
|
||||||
|
/*transpose_a=*/true, /*conjugate_a=*/false,
|
||||||
|
/*block_size=*/2);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<complex64> expected({
|
||||||
|
{0.5, 1., 1.5},
|
||||||
|
{0.41666667, 0.33333333, 0.25},
|
||||||
|
{complex64(0.20020325, -2.81504065e-01),
|
||||||
|
complex64(0.13821138, -4.22764228e-01),
|
||||||
|
complex64(0.07621951, -5.64024390e-01)},
|
||||||
|
{complex64(0.19678492, 2.55912786e-01),
|
||||||
|
complex64(0.17738359, 3.84331116e-01),
|
||||||
|
complex64(0.15798226, 5.12749446e-01)},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<complex64>(&builder, expected,
|
||||||
|
{a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
|
||||||
|
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolveLeftLooking(&builder, a, b,
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*conjugate_a=*/false);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{0.5, 1.0, 1.5},
|
||||||
|
{0.41666667, 0.33333333, 0.25},
|
||||||
|
{0.23148148, 0.18518519, 0.13888889},
|
||||||
|
{0.16835017, 0.13468013, 0.1010101},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
|
||||||
|
xla::ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
xla::ComputationDataHandle a, b;
|
||||||
|
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
|
||||||
|
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
|
||||||
|
auto result = TriangularSolveLeftLooking(&builder, a, b,
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*conjugate_a=*/false);
|
||||||
|
TF_ASSERT_OK(result.status());
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{0.5, 1.0, 1.5},
|
||||||
|
{0.41666667, 0.33333333, 0.25},
|
||||||
|
{0.23148148, 0.18518519, 0.13888889},
|
||||||
|
{0.16835017, 0.13468013, 0.1010101},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
|
||||||
|
xla::ErrorSpec(1e-2, 1e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
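Note on the convention these tests pin down: with `left_side=false` the op solves `X * op(A) = B` for `X`, and with `left_side=true` it solves `op(A) * X = B`, where `op` applies the transpose/conjugate flags. The expectations are plain hand-computable substitutions. A dependency-free sketch (ours, not part of the change) that reproduces the `SimpleRightLowerNotranspose` expectation by back-substitution:

```cpp
#include <cstdio>

// Solves X * A = B for X, where A (4x4) is lower triangular: the
// left_side=false, lower=true, transpose_a=false case. Values are
// AValsLower()/BValsRight() from the tests; output matches `expected`.
int main() {
  const int m = 3, n = 4;
  const double A[4][4] = {
      {2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}};
  const double B[3][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
  double X[3][4];
  for (int i = 0; i < m; ++i) {
    // A is lower triangular, so column j of A only touches X[i][k] for k >= j;
    // solve the columns of row i from right to left.
    for (int j = n - 1; j >= 0; --j) {
      double acc = B[i][j];
      for (int k = j + 1; k < n; ++k) acc -= X[i][k] * A[k][j];
      X[i][j] = acc / A[j][j];
    }
  }
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) std::printf("% .8f ", X[i][j]);
    std::printf("\n");  // row 0: -0.16414141 -0.06902357 -0.07070707 0.36363636
  }
  return 0;
}
```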
@@ -107,4 +107,15 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
   return UpdateSlice(builder, x, update, padded_start);
 }
 
+xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+  TF_RET_CHECK(n_dims >= 2);
+  std::vector<int64> permutation(n_dims);
+  std::iota(permutation.begin(), permutation.end(), 0);
+  std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
+  return builder->Transpose(x, permutation);
+}
+
 }  // namespace tensorflow
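For intuition, `TransposeInMinorDims` swaps only the two trailing (matrix) dimensions and leaves any leading batch dimensions in place: a rank-4 input gets the permutation `{0, 1, 3, 2}`. A dependency-free sketch of just the permutation construction used above (helper name is ours):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Identity permutation with the last two axes exchanged, as built by
// TransposeInMinorDims: std::iota, then std::swap of the two minor dims.
std::vector<int64_t> MinorDimsPermutation(int64_t n_dims) {
  std::vector<int64_t> permutation(n_dims);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
  return permutation;  // n_dims=4 -> {0, 1, 3, 2}
}
```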
@@ -49,6 +49,10 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
     const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
 
+// Transposes a stack of matrices `x` by swapping the last two dimensions.
+xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
+    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
@@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph,
     XlaCompiler::Argument arg;
     arg.kind = XlaCompiler::Argument::kParameter;
     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
-    TensorShape shape;
-    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
-    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
     xla_args->push_back(arg);
   }
@@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types,
 
 bool XlaCompiler::Argument::operator==(
     const XlaCompiler::Argument& other) const {
-  if (std::tie(kind, resource_kind, type, name, tensor_array_size,
+  if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
                tensor_array_gradients) !=
       std::tie(other.kind, other.resource_kind, other.type, other.name,
-               other.tensor_array_size, other.tensor_array_gradients)) {
+               other.initialized, other.tensor_array_size,
+               other.tensor_array_gradients)) {
     return false;
   }
-  if (!xla::ShapeUtil::Equal(shape, other.shape)) {
+  if (shape != other.shape) {
     return false;
   }
   if (constant_value.shape() != other.constant_value.shape()) {
@@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
   return Status::OK();
 }
 
+// Computes the XLA shape for argument 'arg'.
+/*static*/ Status XlaCompiler::XLAShapeForArgument(
+    const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+  switch (arg.kind) {
+    case XlaCompiler::Argument::kConstant:
+      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+                                   xla_shape);
+    case XlaCompiler::Argument::kParameter:
+      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+    case XlaCompiler::Argument::kResource: {
+      TF_RET_CHECK(arg.initialized);
+
+      switch (arg.resource_kind) {
+        case XlaResource::kVariable:
+          return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+        case XlaResource::kTensorArray: {
+          if (arg.tensor_array_size < 0) {
+            return errors::InvalidArgument(
+                "Negative tensor_array_size in XLAShapeForArgument");
+          }
+          TensorShape shape;
+          shape.AddDim(arg.tensor_array_size);
+          shape.AppendShape(arg.shape);
+          TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
+
+          if (!arg.tensor_array_gradients.empty()) {
+            std::vector<xla::Shape> tuple_shape(
+                arg.tensor_array_gradients.size() + 1, *xla_shape);
+            *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
+          }
+          return Status::OK();
+        }
+        case XlaResource::kStack: {
+          if (arg.tensor_array_size < 0) {
+            return errors::InvalidArgument(
+                "Negative tensor_array_size in XLAShapeForArgument");
+          }
+          TensorShape shape;
+          shape.AddDim(arg.tensor_array_size);
+          shape.AppendShape(arg.shape);
+          xla::Shape buffer_shape;
+          TF_RETURN_IF_ERROR(
+              TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
+          *xla_shape = xla::ShapeUtil::MakeTupleShape(
+              {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
+          return Status::OK();
+        }
+
+        case XlaResource::kInvalid:
+          return errors::Internal(
+              "Invalid resource type in XLAShapeForArgument()");
+      }
+    }
+    case XlaCompiler::Argument::kInvalid:
+      return errors::Internal("Invalid argument type in XLAShapeForArgument()");
+  }
+}
+
 namespace {
 
 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
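Shape bookkeeping in the TensorArray branch, worked through for the configuration the updated compiler tests use (`DT_INT32`, per-entry shape `{}`, `tensor_array_size = 2`, one accessed gradient): the buffer shape is `s32[2]`, and packing the value plus one gradient yields the tuple `(s32[2], s32[2])`, exactly the tuple shape the tests previously spelled out by hand. A dependency-free sketch of that dimension arithmetic (helper name is ours, not from the change):

```cpp
#include <cstdint>
#include <vector>

// Buffer dims = [tensor_array_size] followed by the entry dims; the XLA
// parameter is a tuple of (1 + num_gradients) identical buffers whenever
// gradients are accessed.
std::vector<std::vector<int64_t>> TensorArrayParamDims(
    int64_t tensor_array_size, const std::vector<int64_t>& entry_dims,
    int64_t num_gradients) {
  std::vector<int64_t> buffer{tensor_array_size};
  buffer.insert(buffer.end(), entry_dims.begin(), entry_dims.end());
  return std::vector<std::vector<int64_t>>(1 + num_gradients, buffer);
  // size=2, entry_dims={}, one gradient -> {{2}, {2}}, i.e. (s32[2], s32[2]).
}
```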
@@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph,
 
   // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
-  std::vector<int> parameters, resources;
-  parameters.reserve(args.size());
+  input_mapping->clear();
+  input_mapping->reserve(args.size());
+  std::vector<int> resources;
   resources.reserve(args.size());
 
   // Fills in constant arguments, and computes non-constant argument order.
@@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph,
       // TODO(phawkins): this code assumes that resource arguments do not
       // alias.
       XlaResource* resource;
-      TF_RETURN_IF_ERROR(
-          context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
-                                  xla::ComputationDataHandle(), &resource));
-      resource->set_tensor_array_size(arg.tensor_array_size);
+      TF_RETURN_IF_ERROR(context->CreateResource(
+          arg.resource_kind, i, arg.name, arg.type, arg.shape,
+          xla::ComputationDataHandle(),
+          /*tensor_array_size=*/arg.tensor_array_size,
+          /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
      arg_expression.set_resource(resource);
      if (arg.initialized) {
        resources.push_back(i);
      }
      break;
-    case XlaCompiler::Argument::kParameter:
-      parameters.push_back(i);
+    case XlaCompiler::Argument::kParameter: {
+      input_mapping->push_back(i);
       break;
+    }
     case XlaCompiler::Argument::kConstant:
       arg_expression.set_constant_value(arg.constant_value);
       break;
@@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph,
 
   // Append parameters containing variable values after the other runtime
   // parameters.
-  parameters.insert(parameters.end(), resources.begin(), resources.end());
-  if (parameters.empty()) {
+  input_mapping->insert(input_mapping->end(), resources.begin(),
+                        resources.end());
+  if (input_mapping->empty()) {
     return Status::OK();
   }
 
-  std::vector<xla::Shape> arg_shapes;
-  arg_shapes.reserve(parameters.size());
-  input_mapping->resize(parameters.size());
-  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-    const XlaCompiler::Argument& arg = args[parameters[i]];
+  std::vector<xla::Shape> arg_shapes(input_mapping->size());
+  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    arg_shapes.push_back(arg.shape);
-    (*input_mapping)[i] = parameters[i];
+    TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
+        args[(*input_mapping)[i]], &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
@@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph,
   }
 
   // Build parameter handles for non-constant arguments.
-  std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
+  std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
   if (use_tuple_arg) {
     xla::ComputationDataHandle tuple;
     if (is_entry_computation) {
       xla::OpSharding tuple_sharding;
       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
-      for (int64 parameter : parameters) {
+      for (int64 parameter : *input_mapping) {
         const int core = (*arg_cores)[parameter];
         const int root_device = 0;
         *tuple_sharding.add_tuple_shardings() =
@@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph,
     } else {
       tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
     }
-    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-      const int core = (*arg_cores)[parameters[i]];
+    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+      const int core = (*arg_cores)[input_mapping->at(i)];
       xla::ScopedShardingAssignment assign_sharding(
           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
       arg_handles[i] = builder->GetTupleElement(tuple, i);
     }
   } else {
-    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-      const int core = (*arg_cores)[parameters[i]];
+    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+      const int core = (*arg_cores)[input_mapping->at(i)];
       xla::ScopedShardingAssignment assign_sharding(
           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
@@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph,
 
   // Fill in the handles in non-constant arguments.
   VLOG(2) << "XLA computation inputs:";
-  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-    const XlaCompiler::Argument& arg = args[parameters[i]];
+  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+    const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
     VLOG(2) << "  XLA arg " << i
             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
-            << " name: " << arg.name << " TF arg " << parameters[i];
-    XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
+            << " name: " << arg.name << " TF arg " << input_mapping->at(i);
+    XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
     switch (arg.kind) {
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.initialized);
         XlaResource* resource = arg_expression.resource();
-        TF_RETURN_IF_ERROR(
-            resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
-                                  /*reset_initial_values=*/true, builder));
+        TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
                                                 arg_handles[i], builder));
        VLOG(2) << "    resource: num_gradients: "
                << arg.tensor_array_gradients.size();
        break;
@@ -486,6 +545,7 @@ Status BuildComputation(
       XlaCompiler::ResourceUpdate& update = resource_updates->back();
       update.input_index = resource->arg_num();
       update.type = resource->type();
+      update.shape = resource->shape();
       update.modified = modified;
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
@@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       ++computation_output;
     }
   }
-
-  for (std::vector<ResourceUpdate>::size_type i = 0;
-       i < result->resource_updates.size(); ++i) {
-    result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
-        result->xla_output_shape, computation_output);
-    ++computation_output;
-  }
   return Status::OK();
 }
@@ -104,9 +104,17 @@ class XlaCompiler {
     // is the type of the variable's value, not DT_RESOURCE.
     DataType type;
 
-    // The shape of the argument. If the argument is a resource, this is the
-    // shape of the resource's value.
-    xla::Shape shape;
+    // The shape of the argument. For:
+    // * a parameter: the shape of the parameter.
+    // * a constant: ignored; the shape given by constant_value is used
+    //   instead.
+    // * an uninitialized resource: ignored. We don't yet know the shape of an
+    //   uninitialized resource (otherwise we would have initialized it!)
+    // * an initialized variable: the shape of the variable's value.
+    // * an initialized TensorArray or Stack resource: the shape of an entry in
+    //   the TensorArray/Stack. Note this is the size of a single entry, not the
+    //   XLA data structure that represents the complete stack/array.
+    TensorShape shape;
 
     // The value of the argument, if it is a compile-time constant. Must be a
     // host-memory tensor.
@@ -175,8 +183,9 @@ class XlaCompiler {
     int input_index;
 
     // Type and shape of the tensor to be written back.
+    // The `shape` field has the same meaning as the Argument::shape field.
     DataType type;
-    xla::Shape shape;
+    TensorShape shape;
 
     // Was the value of the variable modified by the computation?
     // (Always true, unless `return_updated_values_for_all_resources` is true.)
@@ -235,6 +244,19 @@ class XlaCompiler {
     // device is created, and can be used to create metadata objects
     // that can be accessed by XLA op kernels.
     std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
+
+    // If not nullptr, this memory allocator can be used by the compiler for
+    // temporary allocations it might want to make during compilation.
+    //
+    // For example, the compiler may want to try out different algorithms and
+    // choose the fastest one, and it might run those algorithms over buffers
+    // created using this allocator.
+    //
+    // The compiler can function correctly without an explicit allocator given
+    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
+    // allocate most or all available memory on the device, leaving none for the
+    // compiler to access, unless it can use TensorFlow's allocator.
+    xla::DeviceMemoryAllocator* device_allocator = nullptr;
   };
 
   explicit XlaCompiler(Options options);
@@ -253,11 +275,10 @@ class XlaCompiler {
                          const std::vector<Argument>& args,
                          CompilationResult* result);
 
-  Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
-                          const std::vector<DataType>& types,
-                          const std::vector<TensorShape>& shapes,
-                          const std::vector<const XlaExpression*>& expressions,
-                          std::vector<Argument>* args);
+  // Returns the shape of the XLA parameter for an argument 'arg'.
+  // See the class comment for more details about the argument passing
+  // convention.
+  static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
|
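With `Argument::shape` now a `TensorShape` holding the per-entry shape, a caller describes a TensorArray argument as in this fragment (field values mirror the updated `CanPassTensorArraysToAndFromComputation` test below; the full tuple parameter shape is derived internally by `XLAShapeForArgument` rather than spelled out by the caller):

```cpp
// Fragment, not standalone: describes an initialized int32 TensorArray of two
// scalar entries with one accessed gradient. Note that `shape` is the shape
// of a single entry, not of the packed XLA data structure.
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kTensorArray;
arg.initialized = true;
arg.type = DT_INT32;
arg.shape = TensorShape({});
arg.tensor_array_size = 2;
arg.tensor_array_gradients = {"grad2"};
```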
@@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) {
   std::vector<XlaCompiler::Argument> args(2);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
   args[1].kind = XlaCompiler::Argument::kParameter;
   args[1].type = DT_INT32;
-  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[1].shape = TensorShape({2});
 
   // Compiles the graph.
   XlaCompiler compiler(DefaultOptions());
@@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
   std::vector<XlaCompiler::Argument> args(2);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
   args[1].kind = XlaCompiler::Argument::kParameter;
   args[1].type = DT_INT32;
-  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[1].shape = TensorShape({2});
 
   // Compiles the graph.
   XlaCompiler compiler(DefaultOptions());
@@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   std::vector<XlaCompiler::Argument> args(1);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
 
   XlaCompiler::Options options = DefaultOptions();
   XlaCompiler compiler(options);
@@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
   std::vector<XlaCompiler::Argument> args(1);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
 
   DummyResourceForTest* resource = new DummyResourceForTest();
 
@@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
   std::vector<XlaCompiler::Argument> args(1);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
 
   // Compiles the graph.
   auto options = DefaultOptions();
@@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad2"};
 
@@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad1"};
 
@@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad1"};
 
@@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
 
 xla::ComputationBuilder* XlaContext::builder() { return builder_; }
 
-Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
-                                  string name, DataType type,
-                                  const xla::ComputationDataHandle& handle,
-                                  XlaResource** resource) {
+Status XlaContext::CreateResource(
+    XlaResource::Kind kind, int arg_num, string name, DataType type,
+    TensorShape shape, const xla::ComputationDataHandle& handle,
+    int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
+    XlaResource** resource) {
   resources_.emplace_back(
-      new XlaResource(kind, arg_num, std::move(name), type, handle));
+      new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
+                      handle, tensor_array_size, tensor_array_gradients));
   *resource = resources_.back().get();
   return Status::OK();
 }
@@ -71,11 +71,15 @@ class XlaContext : public ResourceBase {
   Status AddConstRetval(int retval_index, DataType dtype,
                         const xla::Literal& literal);
 
-  // Creates a resource with resource `kind` and initial type `type` and
-  // value `handle`. `name` is a descriptive name for use in error messages.
+  // Creates a resource with resource `kind` and initial value `handle`. `name`
+  // is a descriptive name for use in error messages. See the `XlaResource`
+  // constructor for a description of the remaining arguments.
   // Fails if the resource already exists.
   Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
-                        DataType type, const xla::ComputationDataHandle& handle,
+                        DataType type, TensorShape shape,
+                        const xla::ComputationDataHandle& handle,
+                        int64 tensor_array_size,
+                        const std::set<string>& tensor_array_gradients,
                         XlaResource** resource);
 
   const std::vector<std::unique_ptr<XlaResource>>& resources() {
@@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList(
 }
 
 Status XlaOpKernelContext::ReadVariableInput(
-    int index, xla::ComputationDataHandle* value) {
+    int index, DataType type, TensorShape* shape,
+    xla::ComputationDataHandle* value) {
   const Tensor& tensor = context_->input(index);
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
@@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput(
     return errors::InvalidArgument("Read of uninitialized variable ",
                                    variable->name());
   }
+  if (variable->type() != type) {
+    return errors::InvalidArgument(
+        "Type mismatch for read of variable ", variable->name(), ". Expected ",
+        DataTypeString(type), "; got ", DataTypeString(variable->type()));
+  }
   *value = variable->value();
+  if (shape) {
+    *shape = variable->shape();
+  }
   return Status::OK();
 }
@@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                    variable->name());
   }
   *type = variable->type();
-  auto shape_or_status = builder()->GetShape(variable->value());
-  if (!shape_or_status.ok()) {
-    return shape_or_status.status();
-  }
-  TF_RETURN_IF_ERROR(
-      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+  *shape = variable->shape();
   return Status::OK();
 }
@@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable(
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
-  return variable->SetValue(type, handle);
+
+  auto shape_or_status = builder()->GetShape(handle);
+  if (!shape_or_status.ok()) {
+    return shape_or_status.status();
+  }
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(
+      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+  return variable->SetValue(handle);
 }
 
 XlaCompiler* XlaOpKernelContext::compiler() const {
@@ -164,11 +164,16 @@ class XlaOpKernelContext {
                                   TensorShape* shape) const;
 
   // Reads the current value of the resouce variable referred to by input
-  // 'index'.
-  Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
+  // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+  // variable. Returns an error if the variable has not been initialized, or if
+  // its type does not match `type`.
+  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
+                           xla::ComputationDataHandle* value);
 
   // Assigns the value `handle` to the variable referenced by input
-  // `input_index`. Marks the operator as having side effects.
+  // `input_index`. The variable must be of `type`. Returns an error if the
+  // variable has been initialized with a different type or with a
+  // different shape.
   Status AssignVariable(int input_index, DataType type,
                         const xla::ComputationDataHandle& handle);
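A hypothetical kernel body using the revised variable API (sketched, not taken from the diff; it assumes an op whose input 0 is the resource and input 1 a value of matching shape). The read now states the expected dtype up front and can return the shape without a `builder()->GetShape()` round trip, and the assignment validates type and shape against any prior initialization:

```cpp
// Hypothetical XlaOpKernel::Compile fragment under the new signatures.
void Compile(XlaOpKernelContext* ctx) {
  TensorShape var_shape;
  xla::ComputationDataHandle value;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, DT_FLOAT, &var_shape, &value));
  // Any XLA computation on the value; here, add the op's second input.
  xla::ComputationDataHandle updated = ctx->builder()->Add(value, ctx->Input(1));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, DT_FLOAT, updated));
}
```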
@@ -25,51 +25,99 @@ limitations under the License.
 
 namespace tensorflow {
 
-XlaResource::XlaResource(Kind kind, int arg_num, string name,
-                         DataType initial_type,
-                         const xla::ComputationDataHandle& initial_value)
+XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
+                         TensorShape shape,
+                         const xla::ComputationDataHandle& initial_value,
+                         int64 tensor_array_size,
+                         const std::set<string>& tensor_array_gradients)
     : kind_(kind),
       arg_num_(arg_num),
       name_(std::move(name)),
-      type_(initial_type),
+      type_(type),
+      shape_(std::move(shape)),
       value_(initial_value),
-      initial_value_(initial_value) {
+      initial_value_(initial_value),
+      tensor_array_size_(tensor_array_size) {
   CHECK(kind_ != kInvalid);
+
+  for (const string& gradient : tensor_array_gradients) {
+    tensor_array_gradients_[gradient].reset(
+        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
+                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+                        type_, shape_, xla::ComputationDataHandle(),
+                        tensor_array_size_, /*tensor_array_gradients=*/{}));
+  }
 }
 
-Status XlaResource::SetValue(DataType type,
-                             const xla::ComputationDataHandle& value) {
-  if (type_ == DT_INVALID && type == DT_INVALID) {
-    return errors::InvalidArgument("Attempted to initialized resource ", name_,
-                                   " to an invalid type");
+Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
+  if (type == DT_INVALID) {
+    return errors::InvalidArgument("Attempted to set type of resource '", name_,
+                                   "'' to an invalid type");
   }
-  if (type_ != DT_INVALID && type_ != type) {
+  if (initialized() && type_ != type) {
     return errors::InvalidArgument("Type of resource ", name_,
                                    " cannot be changed after initialization: "
                                    "old type was ",
                                    DataTypeString(type_), ", new type is ",
                                    DataTypeString(type));
   }
+  if (initialized() && shape_ != shape) {
+    return errors::InvalidArgument("Shape of resource ", name_,
+                                   " cannot be changed after initialization: "
+                                   "old shape was ",
+                                   shape_.DebugString(), ", new shape is ",
+                                   shape.DebugString());
+  }
   type_ = type;
+  shape_ = shape;
+  return Status::OK();
+}
+
+Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
+  if (type_ == DT_INVALID) {
+    return errors::InvalidArgument(
+        "Resource '", name_,
+        "' must be initialized with a valid type before use.");
+  }
   value_ = value;
   return Status::OK();
 }
 
-Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
-                                xla::Shape* shape) const {
-  auto shape_or_status = builder->GetShape(value_);
-  if (!shape_or_status.ok()) {
-    return shape_or_status.status();
-  }
-  *shape = *shape_or_status.ValueOrDie();
-  return Status::OK();
-}
-
-Status XlaResource::GetShape(xla::ComputationBuilder* builder,
-                             TensorShape* shape) const {
-  xla::Shape xla_shape;
-  TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
-  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
+Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
+  if (type_ == DT_INVALID) {
+    return errors::InvalidArgument(
+        "Resource '", name_,
+        "' must be initialized with a valid type before use.");
+  }
+  switch (kind_) {
+    case kVariable: {
+      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                  shape_.dim_sizes());
+      break;
+    }
+    case kTensorArray: {
+      TensorShape ta_shape;
+      ta_shape.AddDim(tensor_array_size_);
+      ta_shape.AppendShape(shape_);
+      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                  ta_shape.dim_sizes());
+      break;
+    }
+    case kStack: {
+      TensorShape ta_shape;
+      ta_shape.AddDim(tensor_array_size_);
+      ta_shape.AppendShape(shape_);
+      value_ =
+          builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                             ta_shape.dim_sizes()),
+                          builder->ConstantR0<int32>(0)});
+      break;
+    }
+    case kInvalid:
+    default:
+      LOG(FATAL) << "Invalid resource type";
+  }
   return Status::OK();
 }
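The old single-step `SetValue(type, value)` is now a two-step protocol: declare the type and shape, then install a concrete value or a broadcast zero via `SetZeroValue`. A hedged sketch of the sequence a caller follows (the surrounding code is hypothetical):

```cpp
// Hypothetical initialization of a float vector variable under the split API.
// SetTypeAndShape rejects any type/shape change after initialization;
// SetValue/SetZeroValue reject use before a valid type is set.
XlaResource* var = /* from XlaContext::CreateResource(...) */ nullptr;
TF_RETURN_IF_ERROR(var->SetTypeAndShape(DT_FLOAT, TensorShape({3})));
if (has_initial_value) {
  TF_RETURN_IF_ERROR(var->SetValue(initial_handle));  // an f32[3] handle
} else {
  TF_RETURN_IF_ERROR(var->SetZeroValue(builder));     // f32[3] zeros
}
```

For kind `kStack` the zero value is the tuple `(zeros of shape [tensor_array_size, shape...], s32 0)`, the trailing scalar tracking the stack position.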
@@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
   std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
   if (!gradient) {
     TensorShape ta_shape;
-    TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
+    ta_shape.AddDim(tensor_array_size_);
+    ta_shape.AppendShape(shape_);
     xla::ComputationDataHandle gradient_value = builder->Broadcast(
         XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
     gradient.reset(
         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                         /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
-                        type_, gradient_value));
-    gradient->tensor_array_size_ = tensor_array_size_;
+                        type_, shape_, gradient_value, tensor_array_size_,
+                        /*tensor_array_gradients=*/{}));
   }
   *gradient_out = gradient.get();
   return Status::OK();
 }
 
-Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
-                                xla::Shape* packed_shape) const {
-  if (tensor_array_gradients_.empty()) {
-    return GetXlaShape(builder, packed_shape);
-  }
-  TF_RET_CHECK(kind_ == kTensorArray);
-  std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
-  int pos = 0;
-  TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
-  for (const auto& gradient : tensor_array_gradients_) {
-    TF_RETURN_IF_ERROR(
-        gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
-  }
-  *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
-  return Status::OK();
-}
-
 Status XlaResource::Pack(xla::ComputationDataHandle* pack,
                          xla::ComputationBuilder* builder) const {
   if (tensor_array_gradients_.empty()) {
@@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,
 
 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                 const xla::ComputationDataHandle& pack,
-                                bool reset_initial_values,
                                 xla::ComputationBuilder* builder) {
   if (gradient_sources.empty()) {
+    if (!initialized()) {
+      initial_value_ = pack;
+    }
     value_ = pack;
   } else {
     TF_RET_CHECK(kind_ == kTensorArray);
     int pos = 0;
-    value_ = builder->GetTupleElement(pack, pos++);
+    auto v = builder->GetTupleElement(pack, pos++);
+    if (!initialized()) {
+      initial_value_ = v;
+    }
+    value_ = v;
+
     for (const auto& source : gradient_sources) {
       XlaResource* gradient;
       TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
-      gradient->value_ = builder->GetTupleElement(pack, pos++);
-      if (reset_initial_values) {
-        gradient->initial_value_ = gradient->value_;
+      auto v = builder->GetTupleElement(pack, pos++);
+      if (!gradient->initialized()) {
+        gradient->initial_value_ = v;
       }
+      gradient->value_ = v;
     }
   }
-  if (reset_initial_values) {
-    initial_value_ = value_;
-  }
   return Status::OK();
 }
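`Pack`/`SetFromPack` are the serialization points for TensorArrays with gradients: the packed handle is the raw value when no gradients exist, otherwise a tuple of the value followed by one element per gradient source, in the sorted order a `std::set<string>` iterates. A dependency-free model of that layout (names are ours):

```cpp
#include <set>
#include <string>
#include <vector>

// Element order of the tuple Pack() emits and SetFromPack() consumes:
// the value first, then one element per gradient source in set order.
std::vector<std::string> PackedLayout(const std::set<std::string>& sources) {
  std::vector<std::string> layout{"value"};
  for (const std::string& s : sources) layout.push_back("grad:" + s);
  return layout;  // {"grad2", "grad1"} -> {"value", "grad:grad1", "grad:grad2"}
}
```

After this change `SetFromPack` snapshots `initial_value_` only when the resource (or gradient) was not yet initialized, replacing the old `reset_initial_values` flag.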
@ -36,8 +36,11 @@ class XlaResource {
|
|||||||
kStack,
|
kStack,
|
||||||
};
|
};
|
||||||
|
|
||||||
XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
|
XlaResource(Kind kind, int arg_num, string name, DataType type,
|
||||||
const xla::ComputationDataHandle& initial_value);
|
TensorShape shape,
|
||||||
|
const xla::ComputationDataHandle& initial_value,
|
||||||
|
int64 tensor_array_size,
|
||||||
|
const std::set<string>& tensor_array_gradients);
|
||||||
|
|
||||||
XlaResource(const XlaResource&) = delete;
|
XlaResource(const XlaResource&) = delete;
|
||||||
XlaResource(XlaResource&&) = delete;
|
XlaResource(XlaResource&&) = delete;
|
||||||
@ -60,6 +63,12 @@ class XlaResource {
|
|||||||
// a resource is first initialized we do not yet know its type, so we keep
|
// a resource is first initialized we do not yet know its type, so we keep
|
||||||
// track of its type dynamically.
|
// track of its type dynamically.
|
||||||
DataType type() const { return type_; }
|
DataType type() const { return type_; }
|
||||||
|
|
||||||
|
// Shape of the resource. For an uninitialized resource, this is ignored.
|
||||||
|
// For a Variable, this is the shape of the value. For a TensorArray or Stack
|
||||||
|
// this is the shape of each entry in the TensorArray/Stack.
|
||||||
|
const TensorShape& shape() const { return shape_; }
|
||||||
|
|
||||||
const xla::ComputationDataHandle& value() const { return value_; }
|
const xla::ComputationDataHandle& value() const { return value_; }
|
||||||
|
|
||||||
// Value of the resource at computation entry. Used to detect which
|
// Value of the resource at computation entry. Used to detect which
|
||||||
@ -68,17 +77,19 @@ class XlaResource {
|
|||||||
return initial_value_;
|
return initial_value_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A variable is initialized if it has a value.
|
||||||
bool initialized() const { return value_.handle() > 0; }
|
bool initialized() const { return value_.handle() > 0; }
|
||||||
|
|
||||||
// Sets the current type/value of the resource.
|
// Sets the type and shape of the resource. The type and shape of a resource
|
||||||
Status SetValue(DataType type, const xla::ComputationDataHandle& value);
|
// must not change once the variable has been initialized.
|
||||||
|
Status SetTypeAndShape(DataType type, const TensorShape& shape);
|
||||||
|
|
||||||
// Returns the shape of the resource as an xla::Shape.
|
// Sets the current value of the resource. Returns an error if the type is not
|
||||||
Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
|
// set to a valid value.
|
||||||
|
Status SetValue(const xla::ComputationDataHandle& value);
|
||||||
|
|
||||||
// Returns the shape of the resource as an TensorShape. Fails if the shape is
|
// Sets the current value of the resource to an all-zero value.
|
||||||
// not representable as a TensorShape.
|
Status SetZeroValue(xla::ComputationBuilder* builder);
|
||||||
Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
|
|
||||||
|
|
||||||
// Looks up the gradient for `source`, or creates it if it does not already
|
// Looks up the gradient for `source`, or creates it if it does not already
|
||||||
// exist. The call target must be an initialized TensorArray resource. A
|
// exist. The call target must be an initialized TensorArray resource. A
|
||||||
@ -96,10 +107,6 @@ class XlaResource {
|
|||||||
Status Pack(xla::ComputationDataHandle* pack,
|
Status Pack(xla::ComputationDataHandle* pack,
|
||||||
xla::ComputationBuilder* builder) const;
|
xla::ComputationBuilder* builder) const;
|
||||||
|
|
||||||
// Returns the shape of the `pack` value computed by `Pack()`.
|
|
||||||
Status PackedShape(xla::ComputationBuilder* builder,
|
|
||||||
xla::Shape* packed_shape) const;
|
|
||||||
|
|
||||||
// Updates the resource with values from `pack`. If `gradient_sources` is
|
// Updates the resource with values from `pack`. If `gradient_sources` is
|
||||||
// non-empty, treats `pack` as a tuple that represents a TensorArray and
|
// non-empty, treats `pack` as a tuple that represents a TensorArray and
|
||||||
// its gradients, and unpacks and updates the gradient resources.
|
// its gradients, and unpacks and updates the gradient resources.
|
||||||
@ -108,14 +115,14 @@ class XlaResource {
|
|||||||
// Opposite of Pack().
|
// Opposite of Pack().
|
||||||
Status SetFromPack(const std::set<string>& gradient_sources,
|
Status SetFromPack(const std::set<string>& gradient_sources,
|
||||||
const xla::ComputationDataHandle& pack,
|
const xla::ComputationDataHandle& pack,
|
||||||
bool reset_initial_values,
|
|
||||||
xla::ComputationBuilder* builder);
|
xla::ComputationBuilder* builder);
|
||||||
|
|
||||||
// TensorArray-specific fields
|
// TensorArray and Stack specific fields
|
||||||
|
|
||||||
// 'tensor_array_size' stores the expected size of the TensorArray or Stack.
|
// 'tensor_array_size' stores the expected size of the TensorArray or Stack.
|
||||||
// We need to store this since sometimes TensorArrays must be initialized
|
// We need to store this since sometimes TensorArrays must be initialized
|
||||||
// lazily since we do not know the element shape at construction time.
|
// lazily since we do not know the element shape at construction time.
|
||||||
|
// Used by both TensorArrays and Stacks.
|
||||||
int64 tensor_array_size() const { return tensor_array_size_; }
|
int64 tensor_array_size() const { return tensor_array_size_; }
|
||||||
void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
|
void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
|
||||||
|
|
||||||
@ -136,6 +143,7 @@ class XlaResource {
|
|||||||
const string name_;
|
const string name_;
|
||||||
|
|
||||||
DataType type_;
|
DataType type_;
|
||||||
|
TensorShape shape_;
|
||||||
xla::ComputationDataHandle value_;
|
xla::ComputationDataHandle value_;
|
||||||
xla::ComputationDataHandle initial_value_;
|
xla::ComputationDataHandle initial_value_;
|
||||||
|
|
||||||
|
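A minimal sketch of how the reworked XlaResource surface composes, assuming the resource, builder, and value handle come from the surrounding op context. The helper name and the DT_FLOAT/{2, 3} type-shape pair are illustrative only, not from this commit:

#include "tensorflow/compiler/tf2xla/xla_resource.h"

// Hypothetical helper: pin type/shape once, zero-fill on first touch,
// then update the value via the methods introduced above.
tensorflow::Status WriteVariable(
    tensorflow::XlaResource* resource, xla::ComputationBuilder* builder,
    const xla::ComputationDataHandle& new_value) {
  // Type and shape may not change once the resource is initialized.
  TF_RETURN_IF_ERROR(resource->SetTypeAndShape(
      tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3})));
  if (!resource->initialized()) {
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));  // all-zero start
  }
  return resource->SetValue(new_value);  // value-only update
}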
@@ -88,7 +88,6 @@ cc_library(
     visibility = [":friends"],
     deps = [
         "//tensorflow/core:framework_lite",
-        "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
 )
tensorflow/compiler/xla/client/BUILD
@@ -80,6 +80,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "executable_build_options",
+    srcs = ["executable_build_options.cc"],
+    hdrs = ["executable_build_options.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
+        "//tensorflow/core:lib",
+    ],
+)
+
 cc_library(
     name = "local_client",
     srcs = ["local_client.cc"],
@@ -87,6 +99,7 @@ cc_library(
     deps = [
         ":client",
         ":computation",
+        ":executable_build_options",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
tensorflow/compiler/xla/client/executable_build_options.cc (new file, 79 lines)
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
+    DeviceMemoryAllocator* allocator) {
+  device_allocator_ = allocator;
+  return *this;
+}
+
+DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
+  return device_allocator_;
+}
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
+    int device_ordinal) {
+  CHECK_GE(device_ordinal, 0);
+  device_ordinal_ = device_ordinal;
+  return *this;
+}
+
+int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
+    const Shape& shape_with_layout) {
+  result_layout_set_ = true;
+  result_layout_ = shape_with_layout;
+  return *this;
+}
+
+const Shape* ExecutableBuildOptions::result_layout() const {
+  return result_layout_set_ ? &result_layout_ : nullptr;
+}
+
+string ExecutableBuildOptions::ToString() const {
+  string result_layout = "nullopt";
+  if (result_layout_set_) {
+    result_layout = ShapeUtil::HumanStringWithLayout(result_layout_);
+  }
+  string generate_hlo_graph = "nullopt";
+  if (generate_hlo_graph_.has_value()) {
+    generate_hlo_graph = generate_hlo_graph_.value();
+  }
+  return tensorflow::strings::Printf(
+      "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
+      "generate_hlo_graph=%s}",
+      device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str());
+}
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
+    string regex) {
+  generate_hlo_graph_ = std::move(regex);
+  return *this;
+}
+
+const tensorflow::gtl::optional<string>&
+ExecutableBuildOptions::generate_hlo_graph() const {
+  return generate_hlo_graph_;
+}
+
+}  // namespace xla
tensorflow/compiler/xla/client/executable_build_options.h (new file, 74 lines)
@@ -0,0 +1,74 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+
+namespace xla {
+
+// Class containing options for building an LocalExecutable with
+// LocalClient::Compile.
+class ExecutableBuildOptions {
+ public:
+  // If set, this is the device to build the computation for. Valid
+  // device_ordinal values are: 0 to # of devices - 1. These values are
+  // identical to the device ordinal values used by StreamExecutor. The built
+  // executable will be executable on any device equivalent to the specified
+  // device as determined by Backend::devices_equivalent(). A value of -1
+  // indicates this option has not been set.
+  ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
+  int device_ordinal() const;
+
+  // If set, this specifies the layout of the result of the computation. If not
+  // set, the service will chose the layout of the result. A Shape is used to
+  // store the layout to accommodate tuple result shapes. A value of nullptr
+  // indicates the option has not been set.
+  ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
+  const Shape* result_layout() const;
+
+  // If set, this specifies an allocator that can be used to allocate temporary
+  // space on the device during compilation. For example, the compiler might
+  // want to run various algorithms on the device and pick the fastest one -- it
+  // might allocate buffers for use by these algorithms using this allocator.
+  //
+  // This does not need to be the same as the DeviceMemoryAllocator passed when
+  // running the executable.
+  ExecutableBuildOptions& set_device_allocator(
+      DeviceMemoryAllocator* allocator);
+  DeviceMemoryAllocator* device_allocator() const;
+
+  // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions).
+  ExecutableBuildOptions& set_generate_hlo_graph(string regex);
+  const tensorflow::gtl::optional<string>& generate_hlo_graph() const;
+
+  // Returns a string representation of the build options, suitable for
+  // debugging.
+  string ToString() const;
+
+ private:
+  int device_ordinal_ = -1;
+  Shape result_layout_;
+  bool result_layout_set_ = false;
+  tensorflow::gtl::optional<string> generate_hlo_graph_;
+  DeviceMemoryAllocator* device_allocator_ = nullptr;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
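A short usage sketch of the new options class, composing the chainable setters declared above. The shape, ordinal, and regexp are placeholder values, and the helper name is hypothetical:

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: assembling options for LocalClient::Compile. Any Shape with a
// layout would do for set_result_layout; left unset, the service picks one.
xla::ExecutableBuildOptions MakeOptions() {
  xla::Shape result_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
  xla::ExecutableBuildOptions options;
  options.set_device_ordinal(0)          // must be >= 0 once set (CHECK_GE)
      .set_result_layout(result_shape)
      .set_generate_hlo_graph(".*");     // regexp of HLO graphs to dump
  return options;
}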
tensorflow/compiler/xla/client/local_client.cc
@@ -30,25 +30,6 @@ using xla::source_map_util::InvalidParameterArgument;
 
 namespace xla {
 
-ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
-    int device_ordinal) {
-  device_ordinal_ = device_ordinal;
-  return *this;
-}
-
-int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
-
-ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
-    const Shape& shape_with_layout) {
-  result_layout_set_ = true;
-  result_layout_ = shape_with_layout;
-  return *this;
-}
-
-const Shape* ExecutableBuildOptions::result_layout() const {
-  return result_layout_set_ ? &result_layout_ : nullptr;
-}
-
 namespace {
 StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
                                                    Backend* backend) {
@@ -60,16 +41,18 @@ StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
 }  // namespace
 
 LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
-                                 Backend* backend, int device_ordinal,
-                                 const ExecutableBuildOptions& build_options)
+                                 Backend* backend,
+                                 ExecutableBuildOptions build_options)
     : executable_(std::move(executable)),
       backend_(backend),
-      build_device_ordinal_(device_ordinal),
-      build_options_(build_options) {}
+      build_options_(std::move(build_options)) {
+  CHECK_GE(build_options_.device_ordinal(), 0)
+      << "Must have a valid device ordinal that the executable was built for.";
+}
 
 tensorflow::Status LocalExecutable::ValidateExecutionOptions(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const ExecutableRunOptions& options, const Backend& backend) {
+    const ExecutableRunOptions& run_options, const Backend& backend) {
   const ComputationLayout& computation_layout =
       executable_->module_config().entry_computation_layout();
 
@@ -93,14 +76,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
     }
   }
 
-  if (options.stream() != nullptr) {
-    if (!options.stream()->ok()) {
+  if (run_options.stream() != nullptr) {
+    if (!run_options.stream()->ok()) {
       return InvalidArgument("stream is uninitialized or in an error state");
     }
 
     // Check stream matches service platform.
     const se::Platform* stream_platform =
-        options.stream()->parent()->platform();
+        run_options.stream()->parent()->platform();
     if (stream_platform != backend_->platform()) {
       return InvalidArgument(
           "stream is for platform %s, but service targets platform %s",
@@ -110,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
 
     // Cannot specify device_ordinal with a stream. The stream determines these
    // values.
-    if (options.device_ordinal() != -1) {
+    if (run_options.device_ordinal() != -1) {
      return InvalidArgument(
          "cannot set both device ordinal and stream options in "
          "ExecutableRunOptions; the stream determines the device ordinal");
@@ -119,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
 
   // Verify that the device the executable was built for is equivalent to the
   // device it will run on.
-  int run_device_ordinal = options.device_ordinal() == -1
-                               ? backend_->default_device_ordinal()
-                               : options.device_ordinal();
-  TF_ASSIGN_OR_RETURN(
-      bool devices_equivalent,
-      backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_));
+  int run_device_ordinal = run_options.device_ordinal() == -1
+                               ? backend_->default_device_ordinal()
+                               : run_options.device_ordinal();
+  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
+                      backend_->devices_equivalent(
+                          run_device_ordinal, build_options_.device_ordinal()));
   if (!devices_equivalent) {
     TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                         backend_->stream_executor(run_device_ordinal));
     TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
-                        backend_->stream_executor(build_device_ordinal_));
+                        backend_->stream_executor(build_device_ordinal()));
     return InvalidArgument(
         "executable is built for device %s of type \"%s\"; cannot run it on "
         "device %s of type \"%s\"",
-        backend_->device_name(build_device_ordinal_).c_str(),
+        backend_->device_name(build_device_ordinal()).c_str(),
         build_executor->GetDeviceDescription().name().c_str(),
         backend_->device_name(run_device_ordinal).c_str(),
         run_executor->GetDeviceDescription().name().c_str());
   }
 
-  if (!options.allocator()) {
+  if (!run_options.allocator()) {
     return InvalidArgument("an allocator must be provided to ExecuteLocally");
   }
 
-  if (options.allocator()->platform() != backend.platform()) {
+  if (run_options.allocator()->platform() != backend.platform()) {
     return InvalidArgument(
         "allocator platform (%s) does not match service platform (%s)",
-        options.allocator()->platform()->Name().c_str(),
+        run_options.allocator()->platform()->Name().c_str(),
         backend.platform()->Name().c_str());
   }
 
@@ -155,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
 
 StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const ExecutableRunOptions& options) {
-  TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_));
-  ExecutableRunOptions actual_options = options;
+    ExecutableRunOptions run_options) {
+  TF_RETURN_IF_ERROR(
+      ValidateExecutionOptions(arguments, run_options, *backend_));
 
   Backend::StreamPtr stream;
-  if (options.stream() == nullptr) {
+  if (run_options.stream() == nullptr) {
     // NB!  The lifetime of `stream` needs to match the lifetime of
     // `actual_options` (otherwise we will end up using a returned stream in
     // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
     // scope.
     TF_ASSIGN_OR_RETURN(
-        stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
-    actual_options.set_stream(stream.get());
+        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
+    run_options.set_stream(stream.get());
   }
-  if (options.allocator() == nullptr) {
-    actual_options.set_allocator(backend_->memory_allocator());
+  if (run_options.allocator() == nullptr) {
+    run_options.set_allocator(backend_->memory_allocator());
   }
 
   // For local client execution on CPU backends:
@@ -180,7 +162,7 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
   // *) The thread pool used for XLA CPU ops is from
   //    backend_->eigen_intra_op_thread_pool().
   ServiceExecutableRunOptions service_options(
-      actual_options, backend_->StreamBorrower(),
+      run_options, backend_->StreamBorrower(),
       backend_->eigen_intra_op_thread_pool());
 
   if (executable_->dumping()) {
@@ -189,9 +171,8 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<ShapedBuffer> result,
       executable_->ExecuteOnStreamWrapper(
-          &service_options, options.execution_profile(), arguments));
-  return ScopedShapedBuffer::MakeScoped(result.get(),
-                                        actual_options.allocator());
+          &service_options, run_options.execution_profile(), arguments));
+  return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator());
 }
 
 StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::ExecuteAndDump(
@@ -267,16 +248,19 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
     const Computation& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
     const ExecutableBuildOptions& options) {
-  int device_ordinal = options.device_ordinal() == -1
-                           ? default_device_ordinal()
-                           : options.device_ordinal();
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      local_service_->CompileExecutable(
-                          computation.handle(), argument_layouts,
-                          options.result_layout(), device_ordinal));
+  ExecutableBuildOptions updated_options = options;
+  if (options.device_ordinal() == -1) {
+    updated_options.set_device_ordinal(default_device_ordinal());
+    VLOG(3) << "Set device ordinal to default value of: "
+            << updated_options.device_ordinal();
+  }
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      local_service_->CompileExecutable(computation.handle(), argument_layouts,
+                                        updated_options));
   return WrapUnique(new LocalExecutable(std::move(executable),
                                         local_service_->mutable_backend(),
-                                        device_ordinal, options));
+                                        updated_options));
 }
 
 StatusOr<std::unique_ptr<ScopedShapedBuffer>>
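From the caller's side, taking run_options by value means defaults can simply be left unset; a hedged sketch, where `executable` (a LocalExecutable*) and `arg` (a const ShapedBuffer*) are assumed to exist in the caller:

// Sketch: stream and allocator deliberately left unset; Run() now patches
// its by-value copy in place, borrowing a stream for the device ordinal
// and falling back to the backend's memory allocator.
xla::ExecutableRunOptions run_options;
run_options.set_device_ordinal(-1);  // -1 selects the backend default
auto result_or = executable->Run({arg}, run_options);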
tensorflow/compiler/xla/client/local_client.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -33,39 +34,13 @@ limitations under the License.
 
 namespace xla {
 
-// Class containing options for building an LocalExecutable with
-// LocalClient::Compile.
-class ExecutableBuildOptions {
- public:
-  // If set, this is the device to build the computation for. Valid
-  // device_ordinal values are: 0 to # of devices - 1. These values are
-  // identical to the device ordinal values used by StreamExecutor. The built
-  // executable will be executable on any device equivalent to the specified
-  // device as determined by Backend::devices_equivalent(). A value of -1
-  // indicates this option has not been set.
-  ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
-  int device_ordinal() const;
-
-  // If set, this specifies the layout of the result of the computation. If not
-  // set, the service will chose the layout of the result. A Shape is used to
-  // store the layout to accommodate tuple result shapes. A value of nullptr
-  // indicates the option has not been set.
-  ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
-  const Shape* result_layout() const;
-
- private:
-  int device_ordinal_ = -1;
-  Shape result_layout_;
-  bool result_layout_set_ = false;
-};
-
 class LocalExecutable {
  public:
   // Run the compiled computation with the given arguments and options and
   // return the result.
   StatusOr<std::unique_ptr<ScopedShapedBuffer>> Run(
       const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      const ExecutableRunOptions& options);
+      ExecutableRunOptions run_options);
 
   // Return the layout (contained in a shape) of the result produced by the
   // computation.
@@ -88,8 +63,7 @@ class LocalExecutable {
 
   // Constructor invoked by LocalClient.
   LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
-                  int device_ordinal,
-                  const ExecutableBuildOptions& build_options);
+                  ExecutableBuildOptions build_options);
 
   // Validates that the given arguments and options satisfy various constraints
   // of the computation.
@@ -117,19 +91,19 @@ class LocalExecutable {
   StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
       const ShapedBuffer& shaped_buffer);
 
+  // The ordinal of the device which this executable was compiled for. The
+  // executable can run on all equivalent devices (as determined by
+  // Backend::devices_equivalent).
+  int build_device_ordinal() const { return build_options_.device_ordinal(); }
+
   // Compiled computation.
   std::unique_ptr<Executable> executable_;
 
   // Execution backend.
-  Backend* backend_;
+  Backend* backend_ = nullptr;
 
-  // The ordinal of the device which this executable was compiled for. The
-  // executable can run on all equivalent devices (as determined by
-  // Backend::devices_equivalent).
-  int build_device_ordinal_;
-
   // Options used to build the executable.
-  const ExecutableBuildOptions& build_options_;
+  const ExecutableBuildOptions build_options_;
 };
 
 // An XLA Client specialization for use when the client and service run in
tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -221,13 +221,19 @@ void AllocateFlags() {
           flag_values->xla_gpu_disable_multi_streaming(),
           "If true, multi-streaming in the GPU backend is disabled."),
       tensorflow::Flag(
-          "xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(),
-          "Dump compilation artifacts as proto binary into this directory."),
+          "xla_dump_optimized_hlo_proto_to",
+          flag_values->mutable_xla_dump_optimized_hlo_proto_to(),
+          "Dump Hlo after all hlo passes are executed as proto binary into "
+          "this directory."),
       tensorflow::Flag(
-          "xla_dump_prepass_hlo_proto_to",
-          flag_values->mutable_xla_dump_prepass_hlo_proto_to(),
-          "Dump compilation artifacts, before hlo passes are executed, as "
-          "proto binary into this directory."),
+          "xla_dump_unoptimized_hlo_proto_to",
+          flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(),
+          "Dump HLO before any hlo passes are executed as proto binary into "
+          "this directory."),
+      tensorflow::Flag("xla_dump_per_pass_hlo_proto_to",
+                       flag_values->mutable_xla_dump_per_pass_hlo_proto_to(),
+                       "Dump HLO after each pass as an HloProto in binary file "
+                       "format into this directory."),
       tensorflow::Flag(
           "xla_test_all_output_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
@@ -486,6 +486,7 @@ class Literal {
       std::vector<std::unique_ptr<Literal>> elements);
 
   // Returns a string representation of the literal value.
+  // Warning: this function can take minutes for multi-million element Literals.
   string ToString(bool print_layout = false) const;
 
   // Invokes the "per cell" callback for each element in the provided
tensorflow/compiler/xla/python/BUILD
@@ -49,6 +49,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client:executable_build_options",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:framework_lite",
tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -98,15 +98,25 @@ const std::unique_ptr<ScopedShapedBuffer>& LocalShapedBuffer::shaped_buffer()
   return shaped_buffer_;
 }
 
+static StatusOr<std::unique_ptr<ScopedShapedBuffer>> ToBuffer(
+    LocalClient* client, int device_ordinal, const Literal& arg) {
+  return client->LiteralToShapedBuffer(arg, device_ordinal,
+                                       client->backend().memory_allocator());
+}
+
 /* static */
-LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
+LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
+    const Literal& argument,
+    const tensorflow::gtl::optional<Shape>& shape_with_layout) {
   LocalClient* client = GetOrCreateLocalClient();
-  std::unique_ptr<ScopedShapedBuffer> buf =
-      client
-          ->LiteralToShapedBuffer(argument,
-                                  /*device_ordinal=*/0,
-                                  client->backend().memory_allocator())
-          .ConsumeValueOrDie();
+  std::unique_ptr<ScopedShapedBuffer> buf;
+  if (shape_with_layout) {
+    std::unique_ptr<Literal> relaid =
+        argument.Relayout(shape_with_layout.value());
+    buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie();
+  } else {
+    buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
+  }
   return new LocalShapedBuffer(std::move(buf));
 }
 
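The Relayout-before-transfer pattern above generalizes; a freestanding sketch of the same idea, assuming Literal::CloneToUnique for the no-override copy (the helper name is hypothetical):

// Sketch: produce a literal in the requested layout, or an unchanged copy
// when no layout override is given, mirroring FromLiteral above.
std::unique_ptr<xla::Literal> MaybeRelayout(
    const xla::Literal& literal,
    const tensorflow::gtl::optional<xla::Shape>& shape_with_layout) {
  if (shape_with_layout) {
    return literal.Relayout(shape_with_layout.value());
  }
  return literal.CloneToUnique();
}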
@@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation(
     : executable_(std::move(executable)) {}
 
 StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
-    const std::vector<Literal>& arguments) {
+    const std::vector<Literal>& arguments,
+    const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
   LocalClient* client = GetOrCreateLocalClient();
 
   VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";
@@ -133,7 +144,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
                                 GetReplicaCount());
 
   for (int replica = 0; replica < GetReplicaCount(); ++replica) {
-    pool.Schedule([this, client, replica, &arguments, &results] {
+    pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
+                   &results] {
       StatusOr<int> device_ordinal_status =
           client->ReplicaNumberToDeviceOrdinal(replica);
       if (!device_ordinal_status.ok()) {
@@ -144,18 +156,28 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
       VLOG(3) << "Replica " << replica
               << " mapped to device ordinal for execution: "
               << device_ordinal;
 
       // Transfer arguments in
       std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
       scoped_buffers.reserve(arguments.size());
-      for (const Literal& argument : arguments) {
-        StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
-            client->LiteralToShapedBuffer(
-                argument, device_ordinal,
-                client->backend().memory_allocator());
+      for (int i = 0; i < arguments.size(); ++i) {
+        const Literal& argument = arguments[i];
+        const tensorflow::gtl::optional<Shape>& shape_with_layout =
+            shapes_with_layout[i];
+
+        StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed;
+        if (shape_with_layout) {
+          std::unique_ptr<Literal> relaid =
+              argument.Relayout(shape_with_layout.value());
+          pushed = ToBuffer(client, device_ordinal, *relaid);
+        } else {
+          pushed = ToBuffer(client, device_ordinal, argument);
+        }
         if (!pushed.ok()) {
           results[replica] = pushed.status();
           return;
         }
 
         scoped_buffers.push_back(std::move(pushed).ValueOrDie());
       }
 
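Callers of the widened Execute() pass one optional shape per argument; a sketch of the default-layout case (the wrapper name is hypothetical):

// Sketch: nullopt entries keep each argument's existing layout; a concrete
// Shape would instead force the Relayout path in the loop above.
xla::StatusOr<std::unique_ptr<xla::Literal>> ExecuteWithDefaultLayouts(
    xla::swig::CompiledLocalComputation* computation,
    const std::vector<xla::Literal>& arguments) {
  std::vector<tensorflow::gtl::optional<xla::Shape>> shapes_with_layout(
      arguments.size(), tensorflow::gtl::nullopt);
  return computation->Execute(arguments, shapes_with_layout);
}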
@@ -233,7 +255,8 @@ LocalComputation::LocalComputation(Computation computation)
     : computation_(std::move(computation)) {}
 
 StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
-    const std::vector<Shape>& argument_shapes) {
+    const std::vector<Shape>& argument_shapes,
+    const ExecutableBuildOptions* build_options) {
   std::vector<const Shape*> argument_shape_pointers;
   argument_shape_pointers.reserve(argument_shapes.size());
   for (auto& argument_shape : argument_shapes) {
@@ -242,6 +265,9 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
 
   LocalClient* client = GetOrCreateLocalClient();
   ExecutableBuildOptions options;
+  if (build_options != nullptr) {
+    options = *build_options;
+  }
   TF_ASSIGN_OR_RETURN(
       auto local_executable,
       client->Compile(computation_, argument_shape_pointers, options));
@@ -363,12 +389,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
       source, init_value, scatter.computation());
 }
 
-ComputationDataHandle LocalComputationBuilder::Select(
-    const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
-    const ComputationDataHandle& on_false) {
-  return builder_.Select(pred, on_true, on_false);
-}
-
 ComputationDataHandle LocalComputationBuilder::Tuple(
     tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
   return builder_.Tuple(elements);
@@ -384,6 +404,12 @@ ComputationDataHandle LocalComputationBuilder::Dot(
   return builder_.Dot(lhs, rhs);
 }
 
+ComputationDataHandle LocalComputationBuilder::DotGeneral(
+    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+    const DotDimensionNumbers& dimension_numbers) {
+  return builder_.DotGeneral(lhs, rhs, dimension_numbers);
+}
+
 ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
     const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
     tensorflow::gtl::ArraySlice<int64> window_strides,
@@ -483,6 +509,15 @@ ComputationDataHandle LocalComputationBuilder::While(
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
       (lhs, rhs, broadcast_dimensions))
 
+#define _FORWARD_TRIOP(method_name)                                         \
+  _FORWARD(                                                                 \
+      method_name, ComputationDataHandle,                                   \
+      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,  \
+       const ComputationDataHandle& ehs),                                   \
+      (lhs, rhs, ehs))
+
+_FORWARD_TRIOP(Select)
+_FORWARD_TRIOP(Clamp)
 _FORWARD_BINOP(Eq)
 _FORWARD_BINOP(Ne)
 _FORWARD_BINOP(Ge)
@@ -503,6 +538,7 @@ _FORWARD_UNOP(Abs)
 _FORWARD_UNOP(Exp)
 _FORWARD_UNOP(Floor)
 _FORWARD_UNOP(Ceil)
+_FORWARD_UNOP(Round)
 _FORWARD_UNOP(Log)
 _FORWARD_UNOP(Sign)
 _FORWARD_UNOP(Cos)
@@ -519,6 +555,7 @@ _FORWARD_UNOP(Sort)
 #undef _FORWARD
 #undef _FORWARD_UNOP
 #undef _FORWARD_BINOP
+#undef _FORWARD_TRIOP
 
 void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
   delete local_shaped_buffer;
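For reference, the new ternary forwarding macro generates the same shape of method that the deleted hand-written Select() had; roughly, _FORWARD_TRIOP(Select) in the .cc expands to:

// Approximate expansion (following the _FORWARD pattern used by the binary
// ops): Select's (pred, on_true, on_false) ride in the generic
// (lhs, rhs, ehs) slots, and Clamp reuses the same template.
ComputationDataHandle LocalComputationBuilder::Select(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const ComputationDataHandle& ehs) {
  return builder_.Select(lhs, rhs, ehs);
}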
tensorflow/compiler/xla/python/local_computation_builder.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -58,7 +59,9 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
 // client.
 class LocalShapedBuffer {
  public:
-  static LocalShapedBuffer* FromLiteral(const Literal& argument);
+  static LocalShapedBuffer* FromLiteral(
+      const Literal& argument,
+      const tensorflow::gtl::optional<Shape>& shape_with_layout);
   LocalShapedBuffer(std::unique_ptr<ScopedShapedBuffer> shaped_buffer);
   const std::unique_ptr<ScopedShapedBuffer>& shaped_buffer() const;
   std::unique_ptr<Literal> ToLiteral() const;
@@ -76,8 +79,15 @@ class LocalShapedBuffer {
 class CompiledLocalComputation {
  public:
   CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
+
+  // Execute the computation with the given argument literals, and
+  // with optionally-specified argument layouts. The literals will be
+  // re-laid out according to the corresponding elements of
+  // shapes_with_layout.
   StatusOr<std::unique_ptr<Literal> > Execute(
-      const std::vector<Literal>& arguments);
+      const std::vector<Literal>& arguments,
+      const std::vector<tensorflow::gtl::optional<Shape> >& shapes_with_layout);
+
   LocalShapedBuffer* ExecuteWithShapedBuffers(
       tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
 
@@ -93,7 +103,8 @@ class LocalComputation {
  public:
   LocalComputation(Computation computation);
   StatusOr<CompiledLocalComputation*> Compile(
-      const std::vector<Shape>& argument_shapes);
+      const std::vector<Shape>& argument_shapes,
+      const ExecutableBuildOptions* build_options);
   const Computation& computation() const;
 
  private:
@@ -172,10 +183,6 @@ class LocalComputationBuilder {
       const ComputationDataHandle& source,
       const ComputationDataHandle& init_value, const LocalComputation& scatter);
 
-  ComputationDataHandle Select(const ComputationDataHandle& pred,
-                               const ComputationDataHandle& on_true,
-                               const ComputationDataHandle& on_false);
-
   ComputationDataHandle Tuple(
       tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
 
@@ -185,6 +192,10 @@ class LocalComputationBuilder {
   ComputationDataHandle Dot(const ComputationDataHandle& lhs,
                             const ComputationDataHandle& rhs);
 
+  ComputationDataHandle DotGeneral(
+      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+      const DotDimensionNumbers& dimension_numbers);
+
   ComputationDataHandle ConvGeneralDilated(
       const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
       tensorflow::gtl::ArraySlice<int64> window_strides,
@@ -252,6 +263,14 @@ class LocalComputationBuilder {
       (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
 
+#define _FORWARD_TRIOP(method_name)                                         \
+  _FORWARD(                                                                 \
+      method_name, ComputationDataHandle,                                   \
+      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,  \
+       const ComputationDataHandle& ehs))
+
+_FORWARD_TRIOP(Select)
+_FORWARD_TRIOP(Clamp)
 _FORWARD_BINOP(Eq)
 _FORWARD_BINOP(Ne)
 _FORWARD_BINOP(Ge)
@@ -272,6 +291,7 @@ class LocalComputationBuilder {
 _FORWARD_UNOP(Exp)
 _FORWARD_UNOP(Floor)
 _FORWARD_UNOP(Ceil)
+_FORWARD_UNOP(Round)
 _FORWARD_UNOP(Log)
 _FORWARD_UNOP(Sign)
 _FORWARD_UNOP(Cos)
@@ -288,6 +308,7 @@ class LocalComputationBuilder {
 #undef _FORWARD
 #undef _FORWARD_UNOP
 #undef _FORWARD_BINOP
+#undef _FORWARD_TRIOP
 
  private:
   ComputationBuilder builder_;
tensorflow/compiler/xla/python/local_computation_builder.i
@@ -27,12 +27,14 @@ limitations under the License.
 //  ArraySlice<ComputationDataHandle> <- sequence of int
 //  Literal <-> (nested tuple of) numpy ndarray
 //  std::vector<Literal> <- sequence of (nested tuple of) ndarray
-//  Shape <-> pair holding (dtype, dimensions)
-//  std::vector<Shape> <- sequence of shape information pairs
+//  Shape -> pair holding (dtype, dimensions)
+//        <- object duck-typed as xla_client.Shape
+//  std::vector<Shape> <- sequence of xla_client.Shape objects
 //  PrimitiveType <- int
 //  ArraySlice<pair<int64, in64>> <- sequence of int pairs
 //  PaddingConfig proto <- corresponding Python proto
 //  ConvolutionDimensionNumbers proto <- corresponding Python proto
+//  DotDimensionNumbers proto <- corresponding Python proto
 //
 // Arrows indicate whether a conversion only ever occurs in one
 // direction, or whether it is maintained bidirectionally.
@@ -55,7 +57,7 @@ limitations under the License.
 // translates to a tuple-shaped XLA Literal, whose component subshapes
 // are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
 //
-// The Python objects corresponding to C++ Shapes have the type:
+// Shapes output by C++ become Python objects with the type:
 //
 //  T = (dtype, S)
 //  S = DIMENSIONS | TUPLE_SHAPES
@@ -176,6 +178,16 @@ tensorflow::ImportNumpy();
   }
 }
 
+%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+  if ($1.ok()) {
+    std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
+    $result = numpy::PyObjectFromXlaLiteral(*value);
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    return NULL;
+  }
+}
+
 %typemap(out) StatusOr<xla::swig::LocalComputation*> {
   if ($1.ok()) {
     auto* value = $1.ValueOrDie();
@@ -343,15 +355,31 @@ tensorflow::ImportNumpy();
 // Shape
 
 %typemap(in) const Shape& (Shape temp) {
-  Status shape_status = numpy::CheckPyShapeInfo($input);
-  if (!shape_status.ok()) {
-    PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
+  StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
+  if (!statusor.ok()) {
+    PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
     return NULL;
   }
-  temp = numpy::XlaShapeFromPyShapeInfo($input);
+  temp = std::move(statusor).ValueOrDie();
   $1 = &temp;
 }
 
+%typemap(in) const tensorflow::gtl::optional<Shape>& (
+    tensorflow::gtl::optional<Shape> temp) {
+  if ($input == Py_None) {
+    temp = tensorflow::gtl::nullopt;
+    $1 = &temp;
+  } else {
+    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
+    if (!statusor.ok()) {
+      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
+      return NULL;
+    }
+    temp = std::move(statusor).ValueOrDie();
+    $1 = &temp;
+  }
+}
+
 %typemap(out) std::unique_ptr<Shape> {
   $result = numpy::PyShapeInfoFromXlaShape(*$1);
 }
@@ -364,14 +392,37 @@ tensorflow::ImportNumpy();
   const int size = PySequence_Size($input);
   for (int i = 0; i < size; ++i) {
     PyObject* o = PySequence_GetItem($input, i);
-    Status shape_status = numpy::CheckPyShapeInfo(o);
-    if (!shape_status.ok()) {
-      PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
-      Py_DECREF(o);
+    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
+    Py_DECREF(o);
+    if (!statusor.ok()) {
+      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
       return NULL;
     }
-    temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
+    temps.push_back(statusor.ConsumeValueOrDie());
+  }
+  $1 = &temps;
+}
+
+%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
+    std::vector<tensorflow::gtl::optional<Shape> > temps) {
+  if (!PySequence_Check($input)) {
+    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+    return NULL;
+  }
+  const int size = PySequence_Size($input);
+  for (int i = 0; i < size; ++i) {
+    PyObject* o = PySequence_GetItem($input, i);
+    if (o == Py_None) {
+      temps.push_back(tensorflow::gtl::nullopt);
+    } else {
+      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
+      Py_DECREF(o);
+      if (!statusor.ok()) {
+        PyErr_SetString(PyExc_RuntimeError,
+                        statusor.status().ToString().c_str());
+        return NULL;
+      }
+      temps.push_back(statusor.ConsumeValueOrDie());
+    }
   }
   $1 = &temps;
 }
@@ -461,6 +512,135 @@ tensorflow::ImportNumpy();
   $1 = temps;
 }
 
+// DotDimensionNumbers
+
+%typemap(in) const DotDimensionNumbers&
+    (DotDimensionNumbers dimension_numbers) {
+  int length;
+
+  /* lhs_contracting_dimensions */
+  PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
+      $input, "lhs_contracting_dimensions");
+  if (!lhs_contracting_dimensions) {
+    return NULL;
+  }
+
+  length = PySequence_Size(lhs_contracting_dimensions);
+  if (length == -1) {
+    Py_DECREF(lhs_contracting_dimensions);
+    return NULL;
+  }
+
+  for (int i = 0; i < length; ++i) {
+    PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
+    if (!item) {
+      Py_DECREF(lhs_contracting_dimensions);
+      return NULL;
+    }
+    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+    if (dimension == -1 && PyErr_Occurred()) {
+      Py_DECREF(item);
+      Py_DECREF(lhs_contracting_dimensions);
+      return NULL;
+    }
+    dimension_numbers.add_lhs_contracting_dimensions(dimension);
+    Py_DECREF(item);
+  }
+  Py_DECREF(lhs_contracting_dimensions);
+
+  /* rhs_contracting_dimensions */
+  PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
+      $input, "rhs_contracting_dimensions");
+  if (!rhs_contracting_dimensions) {
+    return NULL;
+  }
+
+  length = PySequence_Size(rhs_contracting_dimensions);
+  if (length == -1) {
+    Py_DECREF(rhs_contracting_dimensions);
+    return NULL;
+  }
+
+  for (int i = 0; i < length; ++i) {
+    PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
+    if (!item) {
+      Py_DECREF(rhs_contracting_dimensions);
+      return NULL;
+    }
+    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+    if (dimension == -1 && PyErr_Occurred()) {
+      Py_DECREF(item);
+      Py_DECREF(rhs_contracting_dimensions);
+      return NULL;
+    }
+    dimension_numbers.add_rhs_contracting_dimensions(dimension);
+    Py_DECREF(item);
+  }
+  Py_DECREF(rhs_contracting_dimensions);
+
+  /* lhs_batch_dimensions */
+  PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
+      $input, "lhs_batch_dimensions");
+  if (!lhs_batch_dimensions) {
+    return NULL;
+  }
+
+  length = PySequence_Size(lhs_batch_dimensions);
+  if (length == -1) {
+    Py_DECREF(lhs_batch_dimensions);
+    return NULL;
+  }
+
+  for (int i = 0; i < length; ++i) {
+    PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
+    if (!item) {
+      Py_DECREF(lhs_batch_dimensions);
+      return NULL;
+    }
+    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+    if (dimension == -1 && PyErr_Occurred()) {
+      Py_DECREF(item);
+      Py_DECREF(lhs_batch_dimensions);
+      return NULL;
+    }
+    dimension_numbers.add_lhs_batch_dimensions(dimension);
+    Py_DECREF(item);
+  }
+  Py_DECREF(lhs_batch_dimensions);
+
+  /* rhs_batch_dimensions */
+  PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
+      $input, "rhs_batch_dimensions");
+  if (!rhs_batch_dimensions) {
+    return NULL;
+  }
+
+  length = PySequence_Size(rhs_batch_dimensions);
+  if (length == -1) {
+    Py_DECREF(rhs_batch_dimensions);
+    return NULL;
+  }
+
+  for (int i = 0; i < length; ++i) {
+    PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
+    if (!item) {
+      Py_DECREF(rhs_batch_dimensions);
+      return NULL;
+    }
+    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
+    if (dimension == -1 && PyErr_Occurred()) {
+      Py_DECREF(item);
+      Py_DECREF(rhs_batch_dimensions);
+      return NULL;
+    }
+    dimension_numbers.add_rhs_batch_dimensions(dimension);
+    Py_DECREF(item);
+  }
+  Py_DECREF(rhs_batch_dimensions);
+
+  $1 = &dimension_numbers;
+}
+
// PaddingConfig
|
// PaddingConfig
|
||||||
|
|
||||||
%typemap(in) const PaddingConfig&
|
%typemap(in) const PaddingConfig&
|
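The typemap above duck-types its input: it only reads four sequence-valued attributes by name, so any Python object exposing them would work, not just the xla_data_pb2.DotDimensionNumbers proto that the client actually passes. A minimal sketch (`MyDotDims` is a hypothetical stand-in):

```python
import collections

# Hypothetical stand-in exposing the four attributes the typemap reads
# via PyObject_GetAttrString; the real caller passes the
# xla_data_pb2.DotDimensionNumbers proto, which has the same fields.
MyDotDims = collections.namedtuple(
    'MyDotDims', ['lhs_contracting_dimensions', 'rhs_contracting_dimensions',
                  'lhs_batch_dimensions', 'rhs_batch_dimensions'])

# Batched matmul: contract lhs dim 2 with rhs dim 1, batch over dim 0.
dims = MyDotDims(lhs_contracting_dimensions=[2],
                 rhs_contracting_dimensions=[1],
                 lhs_batch_dimensions=[0],
                 rhs_batch_dimensions=[0])
```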
@@ -623,6 +803,30 @@ tensorflow::ImportNumpy();
   $1 = &dimension_numbers;
 }
 
+// ExecutableBuildOptions
+
+%typemap(in) const ExecutableBuildOptions*
+    (ExecutableBuildOptions build_options) {
+  if ($input == Py_None) {
+    $1 = NULL;
+  } else {
+    PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
+    if (!o) {
+      return NULL;
+    }
+    if (o != Py_None) {
+      if (!PyString_Check(o)) {
+        PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
+        return NULL;
+      }
+      build_options.set_generate_hlo_graph(PyString_AsString(o));
+    }
+    Py_DECREF(o);
+
+    $1 = &build_options;
+  }
+}
+
 %ignoreall
 %unignore xla;
 %unignore xla::swig;
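For illustration, this typemap accepts None or any object carrying a `generate_hlo_graph` attribute that is a string or None; the CompileOptions class added to xla_client.py later in this commit is the intended provider. A minimal sketch:

```python
# Hypothetical minimal options object satisfying the typemap above;
# only the generate_hlo_graph attribute is consulted.
class Options(object):
  def __init__(self):
    self.generate_hlo_graph = 'my_computation.*'  # regex, or None
```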
@@ -667,6 +871,7 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::LocalComputationBuilder::Call;
 %unignore xla::swig::LocalComputationBuilder::Transpose;
 %unignore xla::swig::LocalComputationBuilder::Rev;
+%unignore xla::swig::LocalComputationBuilder::Clamp;
 %unignore xla::swig::LocalComputationBuilder::Map;
 %unignore xla::swig::LocalComputationBuilder::Reduce;
 %unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
@@ -681,6 +886,7 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::LocalComputationBuilder::Lt;
 %unignore xla::swig::LocalComputationBuilder::Le;
 %unignore xla::swig::LocalComputationBuilder::Dot;
+%unignore xla::swig::LocalComputationBuilder::DotGeneral;
 %unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
 %unignore xla::swig::LocalComputationBuilder::Add;
 %unignore xla::swig::LocalComputationBuilder::Sub;
@@ -696,6 +902,7 @@ tensorflow::ImportNumpy();
 %unignore xla::swig::LocalComputationBuilder::Exp;
 %unignore xla::swig::LocalComputationBuilder::Floor;
 %unignore xla::swig::LocalComputationBuilder::Ceil;
+%unignore xla::swig::LocalComputationBuilder::Round;
 %unignore xla::swig::LocalComputationBuilder::Log;
 %unignore xla::swig::LocalComputationBuilder::Sign;
 %unignore xla::swig::LocalComputationBuilder::Cos;
@@ -176,86 +176,108 @@ static string PyObjectCppRepr(PyObject* o) {
   return ExtractStringAndDecref(r);
 }
 
-Status CheckPyShapeInfo(PyObject* o) {
+StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
   auto error = [o](const string& prefix) {
     return InvalidArgument("%s; got %s", prefix.c_str(),
                            PyObjectCppRepr(o).c_str());
   };
-  // The object is a tuple (a pair)
-  if (!PyTuple_Check(o)) {
-    return error("Shape record must be a tuple");
-  }
-  if (PyTuple_Size(o) != 2) {
-    return error("Shape record tuple must be of length 2");
-  }
-
-  // It has a first element, which is a numpy dtype object
-  PyObject* first = PyTuple_GetItem(o, 0);
-  if (first == nullptr) {
-    return error("Tuple has no item 0 (shape dtype)");
-  }
-  if (first->ob_type != &PyArrayDescr_Type) {
-    return error(
-        "Shape record does not have a numpy dtype as its first element");
-  }
-  const int np_type = NumpyTypenum(first);
-  if (!NumpyTypeIsValid(np_type)) {
-    return error("Shape record has an invalid integer dtype");
-  }
-
-  // It has a second element, which is a tuple, either of shape
-  // records or of Python ints
-  PyObject* second = PyTuple_GetItem(o, 1);
-  if (!second) {
-    return error("Tuple has no item 0 (shape dimensions)");
-  }
-  if (!PyTuple_Check(second)) {
-    return error("Shape record does not have a tuple as its second element");
-  }
-  const int length = PyTuple_Size(second);
-  const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
-  for (int i = 0; i < length; i++) {
-    PyObject* dimension = PyTuple_GetItem(second, i);
-    if (element_type == TUPLE) {
-      VLOG(3) << "element_type is tuple, checking member: " << i;
-      Status result = CheckPyShapeInfo(dimension);
-      if (!result.ok()) {
-        return AddStatus(
-            result, tensorflow::strings::StrCat("Validating tuple member ", i,
-                                                " of ", PyObjectCppRepr(o)));
-      }
-    } else if (!CheckPyIntOrLong(dimension)) {
-      return error("Non-tuple shape record has a non-integer dimension");
-    }
-  }
-
-  return Status::OK();
-}
-
-// Precondition: CheckPyShapeInfo(o)
-Shape XlaShapeFromPyShapeInfo(PyObject* o) {
-  const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
-  const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
-  PyObject* py_dimensions = PyTuple_GetItem(o, 1);
-  const int length = PyTuple_Size(py_dimensions);
-  if (element_type == TUPLE) {
-    std::vector<Shape> subshapes;
-    subshapes.reserve(length);
-    for (int i = 0; i < length; i++) {
-      subshapes.push_back(
-          XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
-    }
-    return ShapeUtil::MakeTupleShape(subshapes);
-  } else {
-    std::vector<int64> dimensions(length);
-    for (int i = 0; i < length; i++) {
-      dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
-      if (dimensions[i] == -1) {
-        CHECK(!PyErr_Occurred());
-      }
-    }
-    return ShapeUtil::MakeShape(element_type, dimensions);
-  }
-}
+
+  auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
+    PyObject* result =
+        PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
+    if (result == nullptr) {
+      return error(tensorflow::strings::StrCat(
+          "Failed to get attribute of Shape object:", field));
+    }
+    return result;
+  };
+
+  auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
+    PyObject* result =
+        PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
+    if (result == nullptr) {
+      return error(tensorflow::strings::StrCat(
+          "Failed to call method of shape object:", method));
+    }
+    return result;
+  };
+
+  PyObject* np_type;
+  TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
+  if (np_type->ob_type != &PyArrayDescr_Type) {
+    return error("Shape attribute np_dtype is not an integer numpy dtype");
+  }
+  if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
+    return error("Shape attribute np_dtype is not a valid integer numpy dtype");
+  }
+  const PrimitiveType element_type =
+      NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
+  Py_DECREF(np_type);
+
+  if (element_type == TUPLE) {
+    PyObject* py_subshapes;
+    TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
+    if (!PyTuple_Check(py_subshapes)) {
+      return error(
+          "Return value of Shape method tuple_shapes() is not a tuple");
+    }
+    const int length = PyTuple_Size(py_subshapes);
+    std::vector<Shape> subshapes;
+    subshapes.reserve(length);
+    for (int i = 0; i < length; i++) {
+      TF_ASSIGN_OR_RETURN(
+          const Shape& subshape,
+          XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
+      subshapes.push_back(subshape);
+    }
+    Py_DECREF(py_subshapes);
+    return ShapeUtil::MakeTupleShape(subshapes);
+  } else {
+    PyObject* py_dimensions;
+    PyObject* py_minor_to_major;
+    TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
+    TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
+    if (!PyTuple_Check(py_dimensions)) {
+      return error("Return value of Shape method dimensions() is not a tuple");
+    }
+    if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
+      return error(
+          "Return value of Shape method minor_to_major() is neither a tuple "
+          "nor None");
+    }
+    const int length = PyTuple_Size(py_dimensions);
+    if (py_minor_to_major != Py_None &&
+        length != PyTuple_Size(py_minor_to_major)) {
+      return error(
+          "Shape methods dimensions() and minor_to_major() return "
+          "different-length tuples");
+    }
+    std::vector<int64> dimensions(length);
+    std::vector<int64> minor_to_major(length);
+    for (int i = 0; i < length; i++) {
+      dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
+      if (dimensions[i] == -1 && PyErr_Occurred()) {
+        return error("Dimension is not an int");
+      }
+
+      if (py_minor_to_major != Py_None) {
+        minor_to_major[i] =
+            PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
+        if (minor_to_major[i] == -1 && PyErr_Occurred()) {
+          return error("Minor-to-major value is not an int");
+        }
+      }
+    }
+    bool with_layout = py_minor_to_major != Py_None;
+    Py_DECREF(py_dimensions);
+    Py_DECREF(py_minor_to_major);
+    if (with_layout) {
+      return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
+                                            minor_to_major);
+    } else {
+      return ShapeUtil::MakeShape(element_type, dimensions);
+    }
+  }
+}
 
 // Helper that retrieves the member with attr_name, stringifies it if is not
@@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type);
 // The return value is a new reference.
 PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
 
-// Returns the outcome of a best-effort check that the Python object
-// is a pair of the form (numpy dtype, dimensions), as produced by
-// PyShapeInfoFromXlaShape.
-Status CheckPyShapeInfo(PyObject* o);
-
-// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
+// Converts a Python object with a method interface mathing that of
+// xla_client.Shape into an XLA Shape object.
 //
 // The return value is a new reference.
-Shape XlaShapeFromPyShapeInfo(PyObject* o);
+StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);
 
 // Converts a PyObject that represents operation metadata into protocol buffer
 // form.
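The rewritten converter no longer takes a (dtype, dimensions) pair; it duck-types anything with the method interface of xla_client.Shape. A minimal sketch of that interface (`FakeShape` is hypothetical, shown only to make the contract concrete):

```python
import numpy as np

class FakeShape(object):
  """Hypothetical minimal object accepted by XlaShapeFromPyShape."""

  def __init__(self, np_dtype, dims, mtm=None):
    self.np_dtype = np_dtype  # attribute: numpy dtype of the leaf array
    self._dims = tuple(dims)
    self._mtm = mtm

  def dimensions(self):       # method: tuple of ints
    return self._dims

  def minor_to_major(self):   # method: layout tuple, or None
    return self._mtm

  def tuple_shapes(self):     # method: only consulted for TUPLE dtypes
    raise ValueError('not a tuple shape')

s = FakeShape(np.dtype('float32'), (2, 3), mtm=(0, 1))
```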
@@ -89,6 +89,7 @@ _UNARY_OPS = [
     'Abs',
     'Exp',
     'Floor',
+    'Round',
     'Ceil',
     'Log',
     'Sign',
@@ -155,9 +156,14 @@ class LocalBuffer(object):
       self._delete = c_api.DeleteLocalShapedBuffer
 
   @staticmethod
-  def from_py(npval):
+  def from_py(npval, layout_fn=None):
     npval = require_numpy_array_layout(npval)
-    return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval))
+    if layout_fn:
+      shape = Shape.from_numpy(npval)
+      shape = shape.map_leaves(layout_fn)
+    else:
+      shape = None
+    return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))
 
   def to_py(self):
     return self.c_local_shaped_buffer.ToLiteral()
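A sketch of the new layout_fn hook, assuming the Shape helpers introduced elsewhere in this diff; `column_major` and `arr` are illustrative names, not part of the commit:

```python
import numpy as np

def column_major(leaf_shape):
  # minor_to_major (0, 1, ..., rank-1) makes dimension 0 minor-most,
  # i.e. column-major for a rank-2 array.
  return leaf_shape.update_minor_to_major(tuple(range(leaf_shape.rank())))

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = LocalBuffer.from_py(arr, layout_fn=column_major)
```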
@@ -182,13 +188,17 @@ class Shape(object):
   represents an XLA tuple.
   """
 
-  def __init__(self, np_dtype, dimensions):
+  def __init__(self, np_dtype, dimensions, minor_to_major=None):
+    assert isinstance(dimensions, tuple)
     self.np_dtype = np_dtype
     self._dimensions = dimensions
+    self._minor_to_major = minor_to_major
+    self._check_minor_to_major()
 
   def __repr__(self):
-    return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format(
-        self.np_dtype, self._dimensions)
+    return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
+            'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
+                                           self._minor_to_major)
 
   def element_type(self):
     return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
@@ -201,11 +211,49 @@ class Shape(object):
       raise ValueError('Tuple shape has no dimensions')
     return self._dimensions
 
+  def minor_to_major(self):
+    return self._minor_to_major
+
   def tuple_shapes(self):
     if not self.is_tuple():
       raise ValueError('Shape is not a tuple shape')
     return self._dimensions
 
+  def rank(self):
+    return len(self.dimensions())
+
+  def map_leaves(self, f):
+    """Map f over each leaf-level array subshape.
+
+    Args:
+      f: The function to apply. Whenever f returns None, the identity is
+        applied instead.
+
+    Returns:
+      A new Shape with the mapped leaves.
+    """
+    if self.is_tuple():
+      children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
+      return Shape(np.dtype('O'), children)
+    else:
+      mapped = f(self)
+      return self if mapped is None else mapped
+
+  def _check_minor_to_major(self):
+    mtm = self._minor_to_major
+    if self.is_tuple():
+      assert mtm is None, self
+    if mtm is not None:
+      assert self.rank() == len(mtm), self
+      assert sorted(mtm) == range(len(mtm)), self
+
+  def update_minor_to_major(self, minor_to_major):
+    if not isinstance(minor_to_major, tuple):
+      raise TypeError('minor_to_major must be a tuple')
+    updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
+    updated._check_minor_to_major()  # pylint: disable=protected-access
+    return updated
+
   @staticmethod
   def from_numpy(npval):
 
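map_leaves only visits array-typed leaves, so a single function can relayout an arbitrarily nested tuple shape; returning None keeps a leaf unchanged. A sketch using the methods above (Python 2 era, matching the module):

```python
import numpy as np

s32 = Shape(np.dtype('int32'), (4, 2))
f32 = Shape(np.dtype('float32'), (3,))
nested = Shape(np.dtype('O'), (s32, f32))  # a tuple shape

def reverse_layout(leaf):
  if leaf.rank() < 2:
    return None  # identity for rank-0/1 leaves
  return leaf.update_minor_to_major(tuple(reversed(range(leaf.rank()))))

relaid = nested.map_leaves(reverse_layout)
```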
@@ -222,23 +270,10 @@ def _wrap_shape(shape_info):
   dtype, dims = shape_info
   element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
   if element_type == xla_data_pb2.TUPLE:
-    dims = [_wrap_shape(subshape_info) for subshape_info in dims]
+    dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
   return Shape(dtype, dims)
 
 
-def _unwrap_shape(shape):
-  if shape.is_tuple():
-    components = tuple(
-        _unwrap_shape(subshape) for subshape in shape.tuple_shapes())
-  else:
-    components = shape.dimensions()
-  return (shape.np_dtype, components)
-
-
-def _unwrap_shapes(shapes):
-  return [_unwrap_shape(shape) for shape in shapes]
-
-
 def _wrap_data_handle(handle):
   cdh = xla_data_pb2.ComputationDataHandle()
   cdh.handle = handle
@@ -260,6 +295,17 @@ def require_numpy_array_layout(value):
     return np.require(value, requirements=['C', 'A'])
 
 
+class CompileOptions(object):
+  """Python object for XLA compile options.
+
+  These options can be passed to the 'compile' step when using a local XLA
+  client.
+  """
+
+  def __init__(self):
+    self.generate_hlo_graph = None
+
+
 def transfer_to_infeed(value, replica_number=None):
   """Transfers the given value into the XLA infeed queue.
 
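A sketch of threading CompileOptions through a compile, assuming the surrounding module's usual builder flow (the builder name and the Build/Add calls are standard xla_client usage, but the exact program below is illustrative):

```python
import numpy as np

options = CompileOptions()
options.generate_hlo_graph = '.*'  # dump HLO graphs for everything

builder = ComputationBuilder('double')
x = builder.ParameterFromNumpy(np.float32(0.0))
builder.Add(x, x)
computation = builder.Build()
compiled = computation.CompileWithExampleArguments(
    arguments=[np.float32(1.0)], compile_options=options)
```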
@@ -291,8 +337,7 @@ def transfer_from_outfeed(shape, replica_number=None):
   Returns:
     The literal value that is produced from the outfeed queue.
   """
-  return c_api.TransferFromOutfeedLocalReplica(
-      _unwrap_shape(shape), replica_number or 0)
+  return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)
 
 
 class LocalComputation(object):
@@ -313,22 +358,39 @@ class LocalComputation(object):
     else:
       self._delete = c_api.DeleteLocalComputation
 
-  def Compile(self, argument_shapes=()):
+  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
     if self.is_compiled:
       raise ValueError('Attempt to compile a compiled local XLA computation.')
+    if layout_fn:
+      argument_shapes = [
+          shape.map_leaves(layout_fn) for shape in argument_shapes
+      ]
     return LocalComputation(
-        self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
+        self.c_local_computation.Compile(argument_shapes, compile_options),
         is_compiled=True)
 
-  def CompileWithExampleArguments(self, arguments=()):
+  def CompileWithExampleArguments(self,
+                                  arguments=(),
+                                  compile_options=None,
+                                  layout_fn=None):
     return self.Compile(
-        argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
+        argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
+        compile_options=compile_options,
+        layout_fn=layout_fn)
 
-  def Execute(self, arguments=()):
+  def Execute(self, arguments=(), layout_fn=None):
+    """Execute with Python values as arguments and return value."""
     if not self.is_compiled:
       raise ValueError('Cannot execute an uncompiled local XLA computation.')
+    argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
+    if layout_fn:
+      argument_shapes = [
+          shape.map_leaves(layout_fn) for shape in argument_shapes
+      ]
+    else:
+      argument_shapes = [None for shape in argument_shapes]
    arguments = tuple(map(require_numpy_array_layout, arguments))
-    return self.c_local_computation.Execute(arguments)
+    return self.c_local_computation.Execute(arguments, argument_shapes)
 
   def ExecuteWithLocalBuffers(self, arguments=()):
     """Execute with LocalBuffer arguments and return value."""
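A sketch of the layout_fn hook end to end: the same callable relayouts the argument shapes at compile time and again at execution time. Program and names are illustrative, under the assumptions already noted for the builder flow:

```python
import numpy as np

def force_column_major(leaf):
  return leaf.update_minor_to_major(tuple(range(leaf.rank())))

builder = ComputationBuilder('matmul')
lhs = builder.ParameterFromNumpy(np.zeros((2, 3), dtype=np.float32))
rhs = builder.ParameterFromNumpy(np.zeros((3, 4), dtype=np.float32))
builder.Dot(lhs, rhs)
computation = builder.Build()

a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.arange(12, dtype=np.float32).reshape(3, 4)
compiled = computation.CompileWithExampleArguments(
    arguments=[a, b], layout_fn=force_column_major)
result = compiled.Execute(arguments=[a, b], layout_fn=force_column_major)
```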
@@ -384,7 +446,7 @@ class ComputationBuilder(object):
     Returns:
       A ComputationDataHandle message.
     """
-    return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape)))
+    return _wrap_data_handle(self._client.Infeed(shape))
 
   def Outfeed(self, operand):
     """Enqueues an outfeed op onto the computation.
@@ -393,7 +455,7 @@ class ComputationBuilder(object):
     outfeed queue for subsequent dequeue via the client API.
     """
     self._client.Outfeed(
-        _unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)),
+        _unwrap_data_handle(operand), self.GetShape(operand),
         ''.encode('utf-8'))
 
   def Constant(self, value):
@@ -484,8 +546,7 @@ class ComputationBuilder(object):
     parameter_num = next(self._parameter_numbering)
 
     return _wrap_data_handle(
-        self._client.Parameter(
-            parameter_num, _unwrap_shape(shape), name.encode('utf8')))
+        self._client.Parameter(parameter_num, shape, name.encode('utf8')))
 
   def ParameterFromNumpy(self, value, name=None, parameter_num=None):
     """Enqueues a Parameter op onto the computation.
@@ -606,6 +667,13 @@ class ComputationBuilder(object):
     return _wrap_data_handle(
         self._client.Rev(_unwrap_data_handle(operand), dimensions))
 
+  def Clamp(self, min, operand, max):  # pylint: disable=redefined-builtin
+    """Clamp op."""
+    return _wrap_data_handle(
+        self._client.Clamp(_unwrap_data_handle(min),
+                           _unwrap_data_handle(operand),
+                           _unwrap_data_handle(max)))
+
   def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
                        padding, source, init_value, scatter):
     """Select and scatter op, used by the gradient of ReduceWindow.
@@ -825,8 +893,7 @@ class ComputationBuilder(object):
     shape = Shape(self.GetShape(mu).np_dtype, dims)
     return _wrap_data_handle(
         self._client.RngNormal(
-            _unwrap_data_handle(mu), _unwrap_data_handle(sigma),
-            _unwrap_shape(shape)))
+            _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))
 
   def RngUniform(self, a, b, dims):
     """Enqueues an RngUniform operation onto the computation.
@@ -846,8 +913,7 @@ class ComputationBuilder(object):
     shape = Shape(self.GetShape(a).np_dtype, dims)
     return _wrap_data_handle(
         self._client.RngUniform(
-            _unwrap_data_handle(a), _unwrap_data_handle(b),
-            _unwrap_shape(shape)))
+            _unwrap_data_handle(a), _unwrap_data_handle(b), shape))
 
   def While(self, cond, body, init):
     """Enqueues a While operation onto the computation.
@@ -865,10 +931,37 @@ class ComputationBuilder(object):
             _unwrap_data_handle(init)))
 
   def Dot(self, lhs, rhs):
-    """Matrix multiplication between lhs and rhs."""
+    """Enqueues a dot operation onto the computation.
+
+    Args:
+      lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
+      rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.
+
+    Returns: a ComputationDataHandle representing the Dot operation.
+    """
     return _wrap_data_handle(
         self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
 
+  def DotGeneral(self, lhs, rhs, dimension_numbers):
+    """Enqueues a general dot operation onto the computation.
+
+    Args:
+      lhs: ComputationDataHandle for the left-hand-side array.
+      rhs: ComputationDataHandle for the right-hand-side array.
+      dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
+        tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
+        integers representing the dimensions to treat as contracting dimensions
+        and batch dimensions on each input operand.
+
+    Returns: a ComputationDataHandle representing the DotGeneral operation.
+    """
+    if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
+      dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
+    return _wrap_data_handle(
+        self._client.DotGeneral(
+            _unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
+            dimension_numbers))
+
   def Conv(self, lhs, rhs, window_strides, padding):
     """Enqueues a Conv operation onto the computation.
 
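A sketch of the nested-tuple convention, checked against plain numpy: contracting lhs dim 2 with rhs dim 1 while batching over dim 0 of both operands is exactly np.matmul on 3-D arrays. Builder setup is illustrative:

```python
import numpy as np

lhs = np.random.randn(10, 3, 4).astype(np.float32)
rhs = np.random.randn(10, 4, 5).astype(np.float32)
dimension_numbers = (([2], [1]), ([0], [0]))

builder = ComputationBuilder('batch_matmul')
builder.DotGeneral(builder.Constant(lhs), builder.Constant(rhs),
                   dimension_numbers)
# The compiled result should agree with np.matmul(lhs, rhs).
```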
@@ -979,7 +1072,7 @@ def initialize_replica_count(replica_count):
 
   Args:
     replica_count: number of replicas that are desired for set up during XLA
-      initalization.
+      initialization.
 
   Raises:
     A runtime exception if the XLA service has already been initialized.
@@ -1005,3 +1098,13 @@ def GetPaddingConfigFromTriples(triples):
     dimension.edge_padding_high = hi
     dimension.interior_padding = interior
   return padding_config
+
+
+def GetDotDimensionsFromLists(dimension_numbers):
+  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
+  dot_dims_proto = xla_data_pb2.DotDimensionNumbers()
+  dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
+  dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
+  dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
+  dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
+  return dot_dims_proto
@@ -444,6 +444,30 @@ class SingleOpTest(LocalComputationTest):
     c.Dot(c.Constant(lhs), c.Constant(rhs))
     self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
 
+  def testDotGeneral(self):
+    c = self._NewComputation()
+    rng = np.random.RandomState(0)
+    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+    dimension_numbers = (([2], [1]), ([0], [0]))
+    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
+  def testDotGeneralWithDotDimensionNumbersProto(self):
+    c = self._NewComputation()
+    rng = np.random.RandomState(0)
+    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
+    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
+
+    dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers()
+    dimension_numbers.lhs_contracting_dimensions.append(2)
+    dimension_numbers.rhs_contracting_dimensions.append(1)
+    dimension_numbers.lhs_batch_dimensions.append(0)
+    dimension_numbers.rhs_batch_dimensions.append(0)
+
+    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
+    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
+
   def testConvF32Same(self):
     c = self._NewComputation()
     a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
@@ -496,6 +520,12 @@ class SingleOpTest(LocalComputationTest):
     c.Exp(c.Constant(arr))
     self._ExecuteAndCompareClose(c, expected=np.exp(arr))
 
+  def testRound(self):
+    c = self._NewComputation()
+    arr = NumpyArrayF32([3.3, 12.1])
+    c.Round(c.Constant(arr))
+    self._ExecuteAndCompareClose(c, expected=np.round(arr))
+
   def testLog(self):
     c = self._NewComputation()
     arr = NumpyArrayF32([3.3, 12.1])
@@ -699,6 +729,23 @@ class SingleOpTest(LocalComputationTest):
     self._ExecuteAndCompareExact(
         c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
 
+  def testClampF32(self):
+    c = self._NewComputation()
+    c.Clamp(
+        c.Constant(NumpyArrayF32(-1)),
+        c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
+        c.Constant(NumpyArrayF32(2)))
+    self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
+
+  # TODO(b/72689392): re-enable when bug S32 resolved
+  def DISABLED_testClampS32(self):
+    c = self._NewComputation()
+    c.Clamp(
+        c.Constant(NumpyArrayS32(-1)),
+        c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
+        c.Constant(NumpyArrayS32(2)))
+    self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2])
+
   def testSelect(self):
     c = self._NewComputation()
     c.Select(
@@ -509,6 +509,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:executable_build_options",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
     ],
@@ -1110,8 +1111,6 @@ cc_library(
         ":hlo",
         ":hlo_evaluator",
         ":hlo_pass",
-        ":tuple_util",
-        ":while_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:lib",
     ],
@@ -1156,6 +1155,34 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "implicit_broadcast_remover",
+    srcs = ["implicit_broadcast_remover.cc"],
+    hdrs = ["implicit_broadcast_remover.h"],
+    deps = [
+        ":hlo",
+        ":hlo_dce",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "implicit_broadcast_remover_test",
+    srcs = ["implicit_broadcast_remover_test.cc"],
+    deps = [
+        ":hlo_matchers",
+        ":implicit_broadcast_remover",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+    ],
+)
+
 cc_library(
     name = "dot_decomposer",
     srcs = ["dot_decomposer.cc"],
@@ -1825,7 +1852,9 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
     ],
 )
 
@@ -1856,6 +1885,7 @@ cc_library(
         ":hlo",
         ":hlo_graph_dumper",
         ":hlo_pass",
+        ":hlo_proto_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
@@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
         reduce,
         HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
   }
 
   // A Transpose feeding a reduce can simply permute the reduction dimensions
-  // field.
-  if (arg->opcode() == HloOpcode::kTranspose) {
+  // field if the output of the reduce is a vector or scalar. Higher ranked
+  // result may require a transpose of the output.
+  if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
+      arg->opcode() == HloOpcode::kTranspose) {
     auto transpose_dimensions = arg->dimensions();
     std::vector<int64> new_reduce_dimensions;
     for (auto dim : dimensions) {
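A sketch, in numpy terms, of why the rank guard above is needed: reducing a transpose equals reducing the original over the permuted dimensions only when the output is a scalar or vector; with a higher-rank output the surviving dimensions come out in transposed order.

```python
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
perm = (1, 2, 0)
# Vector output: two dims reduced, the transpose can simply be dropped
# after permuting the reduction dimensions.
assert np.allclose(x.transpose(perm).sum(axis=(0, 1)),
                   x.sum(axis=(perm[0], perm[1])))
# Rank-2 output: only one dim reduced; dropping the transpose would
# permute the result, as the differing shapes show.
lhs = x.transpose(perm).sum(axis=0)  # shape (4, 2)
rhs = x.sum(axis=perm[0])            # shape (2, 4)
assert lhs.shape != rhs.shape and np.allclose(lhs, rhs.T)
```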
@@ -997,14 +997,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       auto color = single_colored_set.first;
       VLOG(2) << "Simulating heap for color " << color;
       int64 alignment = assignment->color_alignment_(color);
+      HeapSimulator::Options options;
+      options.buffers_to_assign = &single_colored_set.second;
       TF_ASSIGN_OR_RETURN(
           const HeapSimulator::Result result,
           HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
                                  MakeUnique<LazyBestFitHeap>(alignment)),
                              assignment->module(), module_sequence,
                              assignment->points_to_analysis(),
-                             assignment->buffer_size_,
-                             &single_colored_set.second));
+                             assignment->buffer_size_, options));
       AssignBuffersFromHeapSimulator(result, assignment,
                                      single_colored_set.first);
     }
@@ -1024,14 +1025,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
       auto color = single_colored_set.first;
       VLOG(2) << "Simulating heap for color " << color;
       int64 alignment = assignment->color_alignment_(color);
+      HeapSimulator::Options options;
+      options.buffers_to_assign = &single_colored_set.second;
       TF_ASSIGN_OR_RETURN(
           const HeapSimulator::Result result,
           HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
                                  MakeUnique<LazyBestFitHeap>(alignment)),
                              *computation, *instruction_sequence,
                              assignment->points_to_analysis(),
-                             assignment->buffer_size_,
-                             &single_colored_set.second));
+                             assignment->buffer_size_, options));
       AssignBuffersFromHeapSimulator(result, assignment,
                                      single_colored_set.first);
     }
@@ -72,8 +72,18 @@ class AotCompilationOptions {
   // Returns the ID of the platform to which these options apply.
   virtual perftools::gputools::Platform::Id PlatformId() const = 0;
 
+  // Optional allocator that may be used for allocating temp space on the device
+  // during compilation.
+  DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
+  void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
+    device_allocator_ = device_allocator;
+  }
+
  protected:
   AotCompilationOptions() = default;
+
+ private:
+  DeviceMemoryAllocator* device_allocator_ = nullptr;
 };
 
 // Abstract compiler interface that is subclassed for compilation on a
@@ -99,9 +109,16 @@ class Compiler {
 
   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
   // module.
+  //
+  // If device_allocator is not null, the compiler may use it to allocate temp
+  // space on the device for use during compilation. For example, the compiler
+  // may allocate buffers on the device and then run variants of a given
+  // algorithm over those buffers, to see which variant is fastest. Any space
+  // allocated should be deallocated before this function returns.
   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* executor) = 0;
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
@@ -112,21 +129,27 @@ class Compiler {
   // The compiler may optionally specialize to the individual device
   // (not just type of device) indicated by the executor.
   //
+  // device_allocator is optional; see RunHloPasses.
+  //
   // Use the overload below to compile computations that run in parallel.
   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* executor) = 0;
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
   //
+  // device_allocator is optional; see RunHloPasses.
+  //
   // TODO(b/68666782): Remove this method after adding support for multiple
   // modules to RunHloPasses and RunBackends.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
       std::vector<std::unique_ptr<HloModule>> modules,
       std::vector<std::vector<perftools::gputools::StreamExecutor*>>
-          stream_exec) = 0;
+          stream_exec,
+      DeviceMemoryAllocator* device_allocator) = 0;
 
   // Compiles the HLO module for ahead-of-time execution. This is intended for
   // use in static compilation.
@@ -437,7 +437,8 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {
 
 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module,
-    perftools::gputools::StreamExecutor* /*stream_exec*/) {
+    perftools::gputools::StreamExecutor* /*stream_exec*/,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   VLOG(2) << "Before optimization:";
   XLA_VLOG_LINES(2, module->ToString());
 
@@ -450,7 +451,8 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
 
 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module,
-    perftools::gputools::StreamExecutor* stream_exec) {
+    perftools::gputools::StreamExecutor* stream_exec,
+    DeviceMemoryAllocator* /*device_allocator*/) {
   const string timer_message =
       "Compiling [" + module->name() + "] for CPU using JIT";
   XLA_SCOPED_LOGGING_TIMER(timer_message);
@@ -517,8 +519,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
   // ownership is std::moved.
   const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();
-  const string xla_dump_hlo_proto_to =
-      module->config().debug_options().xla_dump_hlo_proto_to();
+  const string xla_dump_optimized_hlo_proto_to =
+      module->config().debug_options().xla_dump_optimized_hlo_proto_to();
 
   if (options::CpuParallelBackendRequested(module->config())) {
     VLOG(1) << "Using parallel cpu backend";
@@ -538,10 +540,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!xla_dump_hlo_proto_to.empty()) {
+    if (!xla_dump_optimized_hlo_proto_to.empty()) {
      HloProto proto = MakeHloProto(*module, *assignment);
       TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
-          proto, xla_dump_hlo_proto_to, module->name()));
+          proto, xla_dump_optimized_hlo_proto_to, module->name()));
     }
 
     // If we are using the parallel CPU backend, we need to create map from
@@ -647,10 +649,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!xla_dump_hlo_proto_to.empty()) {
+    if (!xla_dump_optimized_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
       TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
-          proto, xla_dump_hlo_proto_to, module->name()));
+          proto, xla_dump_optimized_hlo_proto_to, module->name()));
     }
 
     // Each computation is a single function. Emit all embedded computations
@@ -826,12 +828,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    const string xla_dump_hlo_proto_to =
-        module->config().debug_options().xla_dump_hlo_proto_to();
-    if (!xla_dump_hlo_proto_to.empty()) {
+    const string xla_dump_optimized_hlo_proto_to =
+        module->config().debug_options().xla_dump_optimized_hlo_proto_to();
+    if (!xla_dump_optimized_hlo_proto_to.empty()) {
      HloProto proto = MakeHloProto(*module, *assignment);
       TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
-          proto, xla_dump_hlo_proto_to, module->name()));
+          proto, xla_dump_optimized_hlo_proto_to, module->name()));
     }
 
     IrEmitter ir_emitter(*module, *assignment, &llvm_module,
@@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler {
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
-      perftools::gputools::StreamExecutor* stream_exec) override;
+      perftools::gputools::StreamExecutor* stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
|||||||
|
|
||||||
Status IrEmitter::HandleSort(HloInstruction* sort) {
|
Status IrEmitter::HandleSort(HloInstruction* sort) {
|
||||||
// TODO(b/26783907): Implement sort on CPU.
|
// TODO(b/26783907): Implement sort on CPU.
|
||||||
return Unimplemented("Sort is not supported on CPU (b/26783907).");
|
return Unimplemented("Sort is not implemented on CPU.");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||||
@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
|
|||||||
// TODO(b/31410564): Implement dilation for reduce-window.
|
// TODO(b/31410564): Implement dilation for reduce-window.
|
||||||
if (window_util::HasDilation(window)) {
|
if (window_util::HasDilation(window)) {
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"Dilation for reduce-window not implemented on CPU. See b/31410564.");
|
"Dilation for ReduceWindow is not implemented on CPU.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// The called computation should have been emitted previously.
|
// The called computation should have been emitted previously.
|
||||||
@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
|||||||
// TODO(b/31410564): Implement dilation for select-and-scatter.
|
// TODO(b/31410564): Implement dilation for select-and-scatter.
|
||||||
if (window_util::HasDilation(window)) {
|
if (window_util::HasDilation(window)) {
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"Dilation for select-and-scatter not implemented on CPU. "
|
"Dilation for SelectAndScatter is not implemented on CPU. ");
|
||||||
"See b/31410564.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The select and scatter computations should have been emitted previously.
|
// The select and scatter computations should have been emitted previously.
|
||||||
@@ -1196,8 +1195,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
   }
 
   // TODO(b/33011107): Support cross replica sum on CPU.
-  return Unimplemented(
-      "Cross replica sum is not implemented on CPU. See b/33011107.");
+  return Unimplemented("CrossReplicaSum is not implemented on CPU.");
 }
 
 // Fills up the free variables in 'index_with_free_var' with values from
@@ -1811,12 +1809,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
 
 Status IrEmitter::HandleSend(HloInstruction* send) {
   // TODO(b/33942983): Support Send/Recv on CPU.
-  return Unimplemented("Send is not implemented on CPU. See b/33942983.");
+  return Unimplemented("Send is not implemented on CPU.");
 }
 
 Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
   // TODO(b/33942983): Support Send/Recv on CPU.
-  return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
+  return Unimplemented("Send-done is not implemented on CPU.");
 }
 
 Status IrEmitter::HandleSlice(HloInstruction* slice) {
@@ -1981,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice(
 
 Status IrEmitter::HandleRecv(HloInstruction* recv) {
   // TODO(b/33942983): Support Send/Recv on CPU.
-  return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
+  return Unimplemented("Recv is not implemented on CPU.");
 }
 
 Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
   // TODO(b/33942983): Support Send/Recv on CPU.
-  return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
+  return Unimplemented("Recv-done is not implemented on CPU.");
 }
 
 Status IrEmitter::HandlePad(HloInstruction* pad) {
@@ -1995,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
   for (auto& padding_dimension : pad->padding_config().dimensions()) {
     if (padding_dimension.edge_padding_low() < 0 ||
         padding_dimension.edge_padding_high() < 0) {
-      return Unimplemented(
-          "Negative padding not supported in the CPU backend (b/34628603); "
-          "this should have been eliminated at the HLO level: %s",
-          pad->padding_config().ShortDebugString().c_str());
+      return InternalErrorStrCat(
+          "Encountered negative padding in IrEmitter on CPU. "
+          "This should have been eliminated at the HLO level. ",
+          pad->ToString());
     }
   }
 
@@ -24,7 +24,7 @@ limitations under the License.
 namespace xla {
 
 StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
-    perftools::gputools::Platform* platform,
+    const perftools::gputools::Platform* platform,
     tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
         stream_executors)
     : DeviceMemoryAllocator(platform),
@@ -33,7 +33,7 @@ class DeviceMemoryAllocator {
  public:
   // Parameter platform indicates which platform the allocator allocates memory
   // on. Must be non-null.
-  explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform)
+  explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform)
       : platform_(platform) {}
   virtual ~DeviceMemoryAllocator() {}
 
@@ -49,14 +49,14 @@ class DeviceMemoryAllocator {
       int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0;
 
   // Return the platform that the allocator allocates memory on.
-  perftools::gputools::Platform* platform() const { return platform_; }
+  const perftools::gputools::Platform* platform() const { return platform_; }
 
   // Can we call Deallocate() as soon as a computation has been scheduled on
   // a stream, or do we have to wait for the computation to complete first?
   virtual bool AllowsAsynchronousDeallocation() const = 0;
 
  protected:
-  perftools::gputools::Platform* platform_;
+  const perftools::gputools::Platform* platform_;
 };
 
 // Default memory allocator for a platform which uses
@@ -64,7 +64,7 @@ class DeviceMemoryAllocator {
 class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
  public:
   StreamExecutorMemoryAllocator(
-      perftools::gputools::Platform* platform,
+      const perftools::gputools::Platform* platform,
       tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
           stream_executors);
 
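The `const` qualifier added above has to move through the constructor parameter, the accessor, and the stored member in lockstep, which is why the .h and .cc hunks change together. A self-contained toy illustration of the pattern (stand-in types, not XLA code):

```c++
#include <cassert>

struct Platform {};  // stand-in for perftools::gputools::Platform

class Allocator {
 public:
  // Constructor, accessor, and member are const-qualified together;
  // dropping const from any one of them would fail to compile.
  explicit Allocator(const Platform* platform) : platform_(platform) {}
  const Platform* platform() const { return platform_; }

 private:
  const Platform* platform_;
};

int main() {
  Platform p;
  Allocator a(&p);
  assert(a.platform() == &p);
  return 0;
}
```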
@@ -428,7 +428,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
           llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
           ir_builder_);
     case HloOpcode::kSign: {
-      // TODO(b/32151903): Ensure consistent sign behavior for -0.0
+      // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
       auto type = operand_value->getType();
       auto zero = llvm::ConstantFP::get(type, 0.0);
       auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
@@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                       llvm::Value* x) const {
   if (prim_type != F32) {
-    return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
+    // TODO(b/34339814): Implement inverse erf for F64.
+    return Unimplemented(
+        "Inverse erf is only implemented for element "
+        "type F32.");
   }
   auto getFloat = [&](const float f) {
     return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
@@ -1040,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
           is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
           lhs_value, rhs_value, ir_builder_);
     case HloOpcode::kMinimum:
-      return ir_builder_->CreateSelect(
-          ir_builder_->CreateICmp(
-              is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
-              lhs_value, rhs_value),
-          lhs_value, rhs_value);
+      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
     case HloOpcode::kMaximum:
-      return ir_builder_->CreateSelect(
-          ir_builder_->CreateICmp(
-              is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
-              lhs_value, rhs_value),
-          lhs_value, rhs_value);
+      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
     case HloOpcode::kAnd:
       return ir_builder_->CreateAnd(lhs_value, rhs_value);
     case HloOpcode::kOr:
@@ -1067,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
   }
 }
 
+llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
+                                                 llvm::Value* rhs_value,
+                                                 bool is_signed) const {
+  return ir_builder_->CreateSelect(
+      ir_builder_->CreateICmp(
+          is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
+          lhs_value, rhs_value),
+      lhs_value, rhs_value);
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
+                                                 llvm::Value* rhs_value,
+                                                 bool is_signed) const {
+  return ir_builder_->CreateSelect(
+      ir_builder_->CreateICmp(
+          is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
+          lhs_value, rhs_value),
+      lhs_value, rhs_value);
+}
+
 llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
     const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
     int64 operand_no) const {
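Both helpers emit the standard LLVM icmp-plus-select idiom, with the predicate chosen by signedness. A scalar model of the emitted semantics (illustrative C++, not the emitter itself):

```c++
#include <cstdint>

// Scalar model of EmitIntegralMax: a signed (ICMP_SGE) or unsigned
// (ICMP_UGE) compare feeding a select of the two original operands.
int64_t IntegralMax(int64_t lhs, int64_t rhs, bool is_signed) {
  bool pred = is_signed ? (lhs >= rhs)                    // ICMP_SGE
                        : (static_cast<uint64_t>(lhs) >=
                           static_cast<uint64_t>(rhs));   // ICMP_UGE
  return pred ? lhs : rhs;  // CreateSelect(pred, lhs, rhs)
}

// EmitIntegralMin is identical except for the SLE/ULE predicate.
int64_t IntegralMin(int64_t lhs, int64_t rhs, bool is_signed) {
  bool pred = is_signed ? (lhs <= rhs)
                        : (static_cast<uint64_t>(lhs) <=
                           static_cast<uint64_t>(rhs));
  return pred ? lhs : rhs;
}
```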
@@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
         TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                             operand_to_generator.at(hlo->operand(2))(
                                 ElementwiseSourceIndex(index, *hlo, 2)));
+        PrimitiveType prim_type = hlo->shape().element_type();
+        if (primitive_util::IsFloatingPointType(prim_type)) {
           return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+        } else if (primitive_util::IsIntegralType(prim_type)) {
+          bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
+          return EmitIntegralMin(
+              max_value, EmitIntegralMax(min_value, arg_value, is_signed),
+              is_signed);
+        } else {
+          return Unimplemented("Clamp unimplemented for %s",
+                               PrimitiveType_Name(prim_type).c_str());
+        }
       };
     case HloOpcode::kReducePrecision:
       return [this, hlo, &operand_to_generator](
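The integral clamp above composes the two new helpers as min(max_value, max(min_value, arg)). A scalar sketch of that composition for the signed case (illustrative only):

```c++
#include <cstdint>

// clamp(min_value, arg, max_value): first raise arg to the lower bound
// (EmitIntegralMax), then cap the result at the upper bound
// (EmitIntegralMin).
int64_t ClampSigned(int64_t min_value, int64_t arg, int64_t max_value) {
  int64_t raised = arg >= min_value ? arg : min_value;  // IntegralMax
  return max_value <= raised ? max_value : raised;      // IntegralMin
}
```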
@@ -86,6 +86,12 @@ class ElementalIrEmitter {
   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
                                     llvm::Value* rhs_value) const;
 
+  llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
+                               bool is_signed) const;
+
+  llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
+                               bool is_signed) const;
+
   virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
                                             llvm::Value* value) const;
 
@@ -131,6 +131,7 @@ cc_library(
         "ir_emitter_context.h",
     ],
     deps = [
+        ":cudnn_convolution_runner",
         ":elemental_ir_emitter",
         ":gpu_constants",
         ":gpu_executable",
@@ -262,6 +263,7 @@ cc_library(
     ],
     deps = [
         ":buffer_allocations",
+        ":cudnn_convolution_runner",
         ":infeed_manager",
         ":ir_emission_utils",
         ":partition_assignment",
@@ -309,9 +311,41 @@ cc_library(
 )
 
 cc_library(
-    name = "convolution_folding",
-    srcs = ["convolution_folding.cc"],
-    hdrs = ["convolution_folding.h"],
+    name = "cudnn_convolution_algorithm_picker",
+    srcs = ["cudnn_convolution_algorithm_picker.cc"],
+    hdrs = ["cudnn_convolution_algorithm_picker.h"],
+    deps = [
+        ":cudnn_convolution_runner",
+        ":gpu_executable",
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+cc_library(
+    name = "cudnn_convolution_runner",
+    srcs = ["cudnn_convolution_runner.cc"],
+    hdrs = ["cudnn_convolution_runner.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+cc_library(
+    name = "cudnn_convolution_rewriter",
+    srcs = ["cudnn_convolution_rewriter.cc"],
+    hdrs = ["cudnn_convolution_rewriter.h"],
     deps = [
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal_util",
@@ -325,15 +359,18 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "convolution_folding_test",
-    srcs = ["convolution_folding_test.cc"],
+    name = "cudnn_convolution_rewriter_test",
+    srcs = ["cudnn_convolution_rewriter_test.cc"],
     deps = [
-        ":convolution_folding",
+        ":cudnn_convolution_rewriter",
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/service:shape_inference",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
         "//tensorflow/core:test",
     ],
 )
@@ -446,7 +483,8 @@ cc_library(
     srcs = ["gpu_compiler.cc"],
     hdrs = ["gpu_compiler.h"],
     deps = [
-        ":convolution_folding",
+        ":cudnn_convolution_algorithm_picker",
+        ":cudnn_convolution_rewriter",
         ":fusion_merger",
         ":gpu_constants",
         ":gpu_copy_insertion",
@@ -514,7 +552,6 @@ cc_library(
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_pass",
-        "@llvm//:core",
     ],
 )
 
Some files were not shown because too many files have changed in this diff.